diff --git a/src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc b/src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbc484a917abce606fd373a1d1333cef4561a4ae
Binary files /dev/null and b/src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/docs/source/conf.py b/src_code_for_reproducibility/docs/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7512678928b6b7580c812cd62d1c22df9945ba
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/conf.py
@@ -0,0 +1,48 @@
+# Configuration file for the Sphinx documentation builder.
+import os
+import sys
+sys.path.insert(0, os.path.abspath('../..'))
+
+# -- Project information -----------------------------------------------------
+project = 'llm_negotiation'
+copyright = '2023, Your Name'
+author = 'Your Name'
+
+# -- General configuration ---------------------------------------------------
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.viewcode',
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.autosummary',
+ 'sphinx.ext.intersphinx',
+ 'sphinx.ext.mathjax',
+ 'sphinxcontrib.mermaid',
+ 'sphinx_rtd_theme',
+]
+
+templates_path = ['_templates']
+exclude_patterns = []
+
+# -- Options for HTML output -------------------------------------------------
+html_theme = 'sphinx_rtd_theme'
+html_static_path = ['_static']
+
+# -- Napoleon settings -------------------------------------------------------
+napoleon_google_docstring = True
+napoleon_numpy_docstring = False
+napoleon_include_init_with_doc = True
+napoleon_include_private_with_doc = False
+napoleon_include_special_with_doc = True
+napoleon_use_admonition_for_examples = False
+napoleon_use_admonition_for_notes = False
+napoleon_use_admonition_for_references = False
+napoleon_use_ivar = False
+napoleon_use_param = True
+napoleon_use_rtype = True
+napoleon_preprocess_types = False
+napoleon_type_aliases = None
+napoleon_attr_annotations = True
+
+# -- Path setup --------------------------------------------------------------
+# Make sure the project's modules can be found by Sphinx
+sys.path.insert(0, os.path.abspath('../../src'))
\ No newline at end of file
diff --git a/src_code_for_reproducibility/docs/source/contributing.rst b/src_code_for_reproducibility/docs/source/contributing.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src_code_for_reproducibility/docs/source/environments/diplomacy.rst b/src_code_for_reproducibility/docs/source/environments/diplomacy.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c2121d08ecd6e5e13691c05624d22ddadef1f0c3
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/environments/diplomacy.rst
@@ -0,0 +1,459 @@
+=================
+Diplomacy
+=================
+
+The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
+based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
+and its associated agent handler.
+
+Overview
+--------
+
+Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
+and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
+of movement phases, retreat phases, and build phases.
+
+Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
+to be used with LLM agents through a text-based interface.
+
+Game Rules
+----------
+
+**Game Board and Powers**
+
+Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
+
+- England (blue)
+- France (light blue)
+- Germany (black)
+- Italy (green)
+- Austria-Hungary (red)
+- Russia (white)
+- Turkey (yellow)
+
+Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
+
+**Units and Movement**
+
+There are two types of units in Diplomacy:
+- **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
+- **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
+
+During movement phases, each unit can execute one of these orders:
+- **Hold**: The unit remains in its current province (e.g., "A PAR H")
+ - Format: [Unit Type] [Province] H
+ - Example: "A PAR H" means "Army in Paris holds its position"
+
+- **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
+ - Format: [Unit Type] [Current Province] - [Destination Province]
+ - Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
+ - Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
+
+- **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
+ - Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
+ - Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
+ - Example: "A PAR S A MAR - BUR" means "Army in Paris supports the Army in Marseille's move to Burgundy"
+ - Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
+
+- **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
+ - Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
+ - Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
+
+All orders are executed simultaneously, and conflicts are resolved based on strength (number of supporting units).
+
+**Common Province Abbreviations**
+
+Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
+- **PAR**: Paris
+- **LON**: London
+- **BER**: Berlin
+- **MUN**: Munich
+- **BUR**: Burgundy
+- **MAR**: Marseilles
+- **BRE**: Brest
+- **ENG**: English Channel
+- **NTH**: North Sea
+- **VIE**: Vienna
+- **ROM**: Rome
+- **VEN**: Venice
+- **MOW**: Moscow
+- **CON**: Constantinople
+
+**Example: Movement and Conflicts**
+
+For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
+
+**Turn Structure**
+
+A game year consists of five phases:
+1. **Spring Movement**: All powers submit orders for their units
+2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
+3. **Fall Movement**: Another round of movement orders
+4. **Fall Retreat**: Retreat orders for dislodged units
+5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
+
+**Supply Centers and Building**
+
+Supply centers (marked on the map) are key to victory. When a power occupies a supply center during a Fall turn, they gain control of it. During the Winter Adjustment phase:
+- If you control more supply centers than you have units, you can build new units in your home supply centers
+- If you control fewer supply centers than you have units, you must remove excess units
+
+**Example: Building and Removing Units**
+
+If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
+
+**Negotiation**
+
+A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
+
+**Example: Alliance and Betrayal**
+
+England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
+
+**Victory Conditions**
+
+The game ends when one power controls 18 or more supply centers (majority of the 34 total centers), or when players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
+
+DiplomacyEnv
+------------
+
+The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
+Negotiation Environment standard.
+
+.. code-block:: python
+
+ class DiplomacyEnv:
+ """
+ Multi-Agent Negotiation Environment for Diplomacy, adapting Deepmind's implementation
+ to the MarlEnvironment standard.
+ """
+ def __init__(self,
+ initial_state: Optional[DiplomacyState] = None,
+ max_turns: int = 100,
+ points_per_supply_centre: bool = True,
+ forced_draw_probability: float = 0.0,
+ min_years_forced_draw: int = 35):
+ """Initialize the Diplomacy environment.
+
+ Args:
+ initial_state: Initial DiplomacyState (optional)
+ max_turns: Maximum number of turns in the game
+ points_per_supply_centre: Whether to award points per supply center in case of a draw
+ forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
+ min_years_forced_draw: Minimum years before considering a forced draw
+ """
+ # ...
+
+ def reset(self):
+ """Reset the environment to an initial state and return the initial observation.
+
+ Returns:
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
+ Each observation contains:
+ - board_state: Current state of the board
+ - current_season: Current season in the game
+ - player_index: Index of the player's power
+ - possible_actions: List of possible actions in DeepMind's format
+ - human_readable_actions: List of human-readable action descriptions
+ - supply_centers: List of supply centers owned by the player
+ - units: List of units owned by the player
+ - year: Current year in the game
+ """
+ # ...
+
+ def step(self, actions):
+ """Take a step in the environment using the provided actions.
+
+ Args:
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
+ Actions can be:
+ - List of integer actions in DeepMind's format
+ - List of string actions in text format (e.g., "A MUN - BER")
+
+ Returns:
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
+ Each observation has the same structure as in reset().
+ done (bool): Whether the episode has ended.
+ info (dict): Additional information about the environment, including:
+ - turn: Current turn number
+ - returns: Game returns if the game is done, otherwise None
+ - waiting_for: List of agents that still need to provide actions (if not all actions are provided)
+ """
+ # ...
+
+ def get_log_info(self):
+ """Get additional information about the environment for logging.
+
+ Returns:
+ log_info (dict): Information about the environment required to log the game, including:
+ - power_names: List of power names
+ - game_history: History of the game
+ - current_turn: Current turn number
+ - current_season: Current season name
+ - supply_centers: Dictionary mapping power names to supply center counts
+ """
+ # ...
+
+ def render(self):
+ """Render the current state of the environment.
+
+ Displays a visualization of the current game state.
+ """
+ # ...
+
+ def close(self):
+ """Perform any necessary cleanup."""
+ # ...
+
+
+Key Implementation Details
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``DiplomacyEnv`` class implements several key features:
+
+1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
+
+2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
+
+3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
+
+4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
+
+5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
+
+Observation Structure
+~~~~~~~~~~~~~~~~~~~~~
+
+Each agent receives an observation dictionary with the following structure:
+
+.. code-block:: python
+
+ {
+ "board_state": np.ndarray, # Board state representation
+ "current_season": int, # Season index (0-4)
+ "player_index": int, # Index of the player's power (0-6)
+ "possible_actions": [int], # List of possible actions in DeepMind's format
+ "human_readable_actions": [str], # List of human-readable action descriptions
+ "supply_centers": [str], # List of supply centers owned by the player
+ "units": [dict], # List of units owned by the player
+ "year": int # Current year in the game
+ }
+
+Action Structure
+~~~~~~~~~~~~~~~~
+
+Actions can be provided in two formats:
+
+1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
+
+2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
+
+The environment will convert text actions to the internal format as needed.
+
+DiplomacyAgent
+--------------
+
+The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
+
+.. code-block:: python
+
+ class DiplomacyAgent:
+ """
+ Agent handler for Diplomacy, implementing the AgentState interface
+ for the multi-agent negotiation standard.
+ """
+
+ def __init__(self,
+ power_name: str,
+ use_text_interface: bool = True,
+ system_prompt: Optional[str] = None):
+ """Initialize the Diplomacy agent handler.
+
+ Args:
+ power_name: Name of the power this agent controls
+ use_text_interface: Whether to use text-based interface (vs. structured)
+ system_prompt: Optional system prompt to use for the LLM
+ """
+ # ...
+
+ def step(self, observation_from_env, policy_output=None):
+ """Update the agent state based on the observation and action.
+
+ Args:
+ observation_from_env: The observation from the environment, with structure:
+ - board_state: Current state of the board
+ - current_season: Current season in the game
+ - player_index: Index of the player's power
+ - possible_actions: List of possible actions
+ - human_readable_actions: List of human-readable action descriptions
+ - supply_centers: List of supply centers owned by the player
+ - units: List of units owned by the player
+ - year: Current year in the game
+
+ policy_output: The output of the policy (LLM response), or None for initial prompt
+
+ Returns:
+ policy_id (str): The policy identifier ("llm_policy")
+ policy_input (dict): The input to the policy, with structure:
+ - messages: List of conversation messages in the format:
+ [{"role": "system", "content": "..."},
+ {"role": "user", "content": "..."}]
+ action: The official action to be sent to the environment, or None if not ready
+ done (bool): Whether the LLM action is ready to be sent to the environment
+ info (dict): Additional information about the agent:
+ - valid_action: Whether the extracted action is valid
+ """
+ # ...
+
+ def get_log_info(self):
+ """Get information about the agent required to log a trajectory.
+
+ Returns:
+ log_info (dict): Information about the agent required to log a trajectory:
+ - power_name: Name of the power this agent controls
+ - conversation_history: List of conversation messages
+ - current_action: The current action, if any
+ """
+ # ...
+
+ def render(self):
+ """Render the current state of the agent.
+
+ Displays the agent's current state, including conversation history.
+ """
+ # ...
+
+ def close(self):
+ """Perform any necessary cleanup."""
+ # ...
+
+
+Key Implementation Details
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``DiplomacyAgent`` class implements several key features:
+
+1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
+
+2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
+
+3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
+
+4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
+
+5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
+
+Prompt Structure
+~~~~~~~~~~~~~~~~
+
+The agent generates prompts that include:
+
+1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
+
+2. **Game State Description**: A text description of the current game state, including:
+ - Current year and season
+ - Supply centers owned
+ - Units controlled
+ - Possible actions
+
+3. **Action Request**: Instructions on how to format actions.
+
+Example system prompt:
+
+.. code-block:: text
+
+ You are playing the role of FRANCE in a game of Diplomacy.
+ Your goal is to control as many supply centers as possible.
+ You can negotiate with other players and form alliances, but remember that
+ these alliances are not binding. When you need to submit orders for your units,
+ write them in the correct format, with each order on a new line.
+
+Example game state description:
+
+.. code-block:: text
+
+ Year: 1901, Season: SPRING_MOVES
+ You are playing as FRANCE.
+ You currently control 3 supply centers: PAR, MAR, BRE.
+ Your units are: A PAR, A MAR, F BRE.
+
+ Please provide orders for your units. Here are your possible actions:
+ A PAR - BUR
+ A PAR - GAS
+ A PAR - PIC
+ A PAR H
+ ...
+
+ Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
+
+Running Diplomacy Games
+-----------------------
+
+To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
+
+.. code-block:: python
+
+ from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
+ from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
+ from mllm.run_matches import run_batched_matches
+
+ # Create environment and agent handlers
+ env = DiplomacyEnv(max_turns=30)
+
+ agent_handlers = {
+ "AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
+ "ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
+ "FRANCE": DiplomacyAgent(power_name="FRANCE"),
+ "GERMANY": DiplomacyAgent(power_name="GERMANY"),
+ "ITALY": DiplomacyAgent(power_name="ITALY"),
+ "RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
+ "TURKEY": DiplomacyAgent(power_name="TURKEY")
+ }
+
+ # Define policy mapping (mapping from policy IDs to actual policy functions)
+ policy_mapping = {
+ "llm_policy": my_llm_policy_function
+ }
+
+ # Run the game
+ game_results = run_batched_matches(
+ envs=[env],
+ agent_handlers_per_env=[agent_handlers],
+ policy_mapping=policy_mapping,
+ max_parallel_matches=1
+ )
+
+ # Process results
+ for result in game_results:
+ print(f"Game finished. Winner: {result['winner']}")
+ print(f"Supply centers: {result['supply_centers']}")
+
+This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
+
+Limitations and Considerations
+------------------------------
+
+1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
+
+2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
+
+3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
+
+4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
+
+5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
+
+Advanced Usage
+--------------
+
+For advanced usage, you can customize:
+
+1. **System Prompts**: Modify agent behavior by providing custom system prompts.
+
+2. **Observation Processing**: Extend the observation processing to include additional information.
+
+3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
+
+4. **Visualization**: Add custom visualization methods to the environment's render function.
+
+5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
\ No newline at end of file
diff --git a/src_code_for_reproducibility/docs/source/environments/ipd.rst b/src_code_for_reproducibility/docs/source/environments/ipd.rst
new file mode 100644
index 0000000000000000000000000000000000000000..98e55d0c72f29f026c2c5d27650f51f60a7e7601
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/environments/ipd.rst
@@ -0,0 +1,411 @@
+===========================
+Iterated Prisoner's Dilemma
+===========================
+
+The Iterated Prisoner's Dilemma environment provides a classic game theory setting for studying cooperation
+and competition between agents. This document describes the API for interacting with the IPD environment
+and its associated agent handler.
+
+Overview
+--------
+
+The Prisoner's Dilemma is a fundamental problem in game theory that demonstrates why two rational individuals might not
+cooperate, even when it appears in their best interest to do so. In the iterated version, the same two players
+repeatedly face the same dilemma, allowing for the development of trust or retaliation based on previous interactions.
+
+Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used with
+LLM agents through a text-based interface.
+
+Game Rules
+----------
+
+**Basic Premise**
+
+The scenario behind the Prisoner's Dilemma is as follows:
+
+Two criminals are arrested and imprisoned. Each prisoner is in solitary confinement with no means of communicating with
+the other. The prosecutors lack sufficient evidence to convict the pair on the principal charge, but they have enough
+to convict both on a lesser charge. Simultaneously, the prosecutors offer each prisoner a bargain:
+
+- If both prisoners betray each other, each serves 2 years in prison (the "punishment" payoff)
+- If one betrays the other while the other remains silent, the betrayer goes free (the "temptation" payoff) while the
+ silent accomplice serves 3 years (the "sucker" payoff)
+- If both remain silent, each serves only 1 year in prison (the "reward" payoff)
+
+**Game Mechanics**
+
+In our implementation, the choices are simplified to:
+- **C**: Cooperate (remain silent)
+- **D**: Defect (betray the other prisoner)
+
+Each round, both players simultaneously choose either C or D, and receive points based on the combination of their choices:
+
+- Both choose C: Both receive the "reward" payoff (3 points by default)
+- Both choose D: Both receive the "punishment" payoff (1 point by default)
+- One chooses C, one chooses D: The defector receives the "temptation" payoff (5 points by default), while the cooperator
+ receives the "sucker" payoff (0 points by default)
+
+**Example: Single Round**
+
+Let's see how a single round plays out:
+
+1. Alice and Bob simultaneously make their choices
+2. If Alice chooses C and Bob chooses C:
+ - Alice receives 3 points
+ - Bob receives 3 points
+3. If Alice chooses C and Bob chooses D:
+ - Alice receives 0 points
+ - Bob receives 5 points
+4. If Alice chooses D and Bob chooses C:
+ - Alice receives 5 points
+ - Bob receives 0 points
+5. If Alice chooses D and Bob chooses D:
+ - Alice receives 1 point
+ - Bob receives 1 point
+
+**Iterated Game Structure**
+
+The iterated version repeats this basic game for a fixed number of rounds. The key features are:
+
+1. Players know the total number of rounds in advance
+2. After each round, players learn what choice the other player made
+3. Players maintain a cumulative score across all rounds
+4. Players can adjust their strategy based on the history of previous interactions
+
+**Game Variations**
+
+The IPD environment supports several variations through configuration parameters:
+
+*Different Payoff Matrices*
+
+The standard payoff values can be modified to create different incentive structures:
+- **Traditional PD**: reward=3, punishment=1, temptation=5, sucker=0
+- **Weak Temptation**: reward=3, punishment=1, temptation=4, sucker=0 (reduces the incentive to defect)
+- **Harsh Punishment**: reward=3, punishment=0, temptation=5, sucker=0 (increases the cost of mutual defection)
+- **Generous**: reward=4, punishment=2, temptation=5, sucker=1 (cushions the blow of being betrayed)
+
+*Game Length Variations*
+
+The number of rounds can significantly impact strategy:
+- **Short Games** (5-10 rounds): Incentivizes more defection, especially near the end
+- **Medium Games** (20-50 rounds): Allows for the development of tit-for-tat and forgiveness strategies
+- **Long Games** (100+ rounds): Favors steady cooperation with occasional "probing" defections
+
+**Common Strategies**
+
+While not enforced by the environment, several well-known strategies can emerge:
+- **Always Cooperate**: Always choose C
+- **Always Defect**: Always choose D
+- **Tit for Tat**: Start with C, then copy what the opponent did in the previous round
+- **Forgiving Tit for Tat**: Like Tit for Tat, but occasionally cooperate even after being defected against
+- **Grudger**: Cooperate until the opponent defects once, then always defect
+- **Random**: Choose randomly between C and D
+
+IPDEnv
+------
+
+The ``IPDEnv`` class provides an interface to the Iterated Prisoner's Dilemma environment that follows the
+Multi-Agent Negotiation Environment standard.
+
+.. code-block:: python
+
+ class IPDEnv:
+ """
+ Iterated Prisoner's Dilemma environment following the MarlEnvironment standard.
+
+ In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
+ The payoffs are as follows:
+ - If both cooperate: Both receive the "reward" (usually 3 points)
+ - If both defect: Both receive the "punishment" (usually 1 point)
+ - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
+ and the cooperator receives the "sucker" payoff (usually 0 points)
+
+ The game is played for a specified number of rounds.
+ """
+
+ def __init__(
+ self,
+ rounds_per_game: int = 10,
+ reward: float = 3.0, # Both cooperate
+ punishment: float = 1.0, # Both defect
+ temptation: float = 5.0, # Defector's reward when other cooperates
+ sucker: float = 0.0, # Cooperator's reward when other defects
+ random_seed: Optional[int] = None,
+ ):
+ """
+ Initialize the Iterated Prisoner's Dilemma environment.
+
+ Args:
+ rounds_per_game: Number of rounds to play
+ reward: Payoff when both agents cooperate
+ punishment: Payoff when both agents defect
+ temptation: Payoff for defecting when other agent cooperates
+ sucker: Payoff for cooperating when other agent defects
+            random_seed: Random seed for reproducibility
+ """
+ # ...
+
+ def reset(self) -> Dict[str, Dict[str, Any]]:
+ """
+ Reset the environment to an initial state and return the initial observation.
+
+ Returns:
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
+ """
+ # ...
+
+ def step(self, actions: Dict[str, str]) -> Tuple[Dict[str, Dict[str, Any]], bool, Dict[str, Any]]:
+ """
+ Take a step in the environment using the provided actions.
+
+ Args:
+ actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
+
+ Returns:
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
+ done (bool): Whether the episode has ended.
+ info (dict): Additional information about the environment.
+ """
+ # ...
+
+Key Implementation Details
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``IPDEnv`` class implements several key features:
+
+1. **Two-Agent Support**: The environment tracks two agents ("alice" and "bob") and manages their interactions.
+
+2. **Round-Based Play**: The environment enforces turn structure and tracks game history.
+
+3. **Payoff Matrix**: The environment calculates rewards based on the standard prisoner's dilemma payoff matrix.
+
+4. **Observation Generation**: The environment generates detailed observations for each agent, including action history and rewards.
+
+5. **Game Termination**: The environment tracks game termination after the specified number of rounds.
+
+Observation Structure
+~~~~~~~~~~~~~~~~~~~~~
+
+Each agent receives an observation dictionary with the following structure:
+
+.. code-block:: python
+
+ {
+ "current_round": int, # Current round number (0-indexed)
+ "rounds_per_game": int, # Total number of rounds in the game
+ "history": List[Dict], # Complete game history so far
+ "last_round_actions": Dict[str, str], # Actions from the previous round (if any)
+ "last_round_reward": float, # Reward received in the previous round (if any)
+ "total_reward": float, # Cumulative reward so far
+ "payoff_matrix": Dict[str, float], # The game's payoff matrix values
+ }
+
+Action Structure
+~~~~~~~~~~~~~~~~
+
+Actions are simple strings:
+
+1. ``"C"`` for Cooperate
+2. ``"D"`` for Defect
+
+IPDAgent
+--------------
+
+The ``IPDAgent`` class implements the agent handler interface for the Iterated Prisoner's Dilemma, processing observations from the environment and generating actions through an LLM.
+
+.. code-block:: python
+
+ class IPDAgent:
+ """
+ Agent handler for Iterated Prisoner's Dilemma, implementing the AgentState interface
+ for the multi-agent negotiation standard.
+ """
+
+ def __init__(
+ self,
+ agent_id: str,
+ policy_id: str = "llm_policy",
+ system_prompt: Optional[str] = None,
+ max_errors: int = 3,
+ opponent_id: Optional[str] = None,
+ ):
+ """
+ Initialize the IPD agent handler.
+
+ Args:
+ agent_id: Identifier for this agent ("alice" or "bob")
+ policy_id: Identifier for the policy this agent uses
+ system_prompt: Optional custom system prompt for the LLM
+ max_errors: Maximum number of parsing errors before defaulting to cooperate
+ opponent_id: Optional identifier of the opponent (inferred if not provided)
+ """
+ # ...
+
+ def step(self, observation_from_env: Dict[str, Any], policy_output: str = None) -> Tuple[str, Dict[str, Any], str, bool, Dict[str, Any]]:
+ """
+ Update the agent state based on the observation and process the policy output.
+
+ Args:
+ observation_from_env: The observation from the environment
+ policy_output: The output from the policy (LLM response)
+
+ Returns:
+ policy_id: The policy identifier
+ policy_input: The input to the policy
+ action: The action to be sent to the environment
+ done: Whether the action is ready to be sent to the environment
+ info: Additional information about the agent
+ """
+ # ...
+
+Key Implementation Details
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``IPDAgent`` class implements several key features:
+
+1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses.
+
+2. **Action Extraction**: The agent parses the LLM's output to extract valid actions (C or D).
+
+3. **Error Handling**: The agent provides helpful error messages when parsing fails and defaults to cooperation after multiple failures.
+
+4. **History Tracking**: The agent maintains and provides the complete game history in its prompts.
+
+5. **Strategy Explanation**: The agent can extract and log the reasoning behind an LLM's decisions.
+
+Prompt Structure
+~~~~~~~~~~~~~~~~
+
+The agent generates prompts that include:
+
+1. **System Prompt**: Instructions and context for the LLM, explaining its role and the rules of the Prisoner's Dilemma.
+
+2. **Game State Description**: A text description of the current game state, including:
+ - Current round number
+ - History of previous rounds (if any)
+ - Cumulative score
+
+3. **Action Request**: Instructions on how to format the response, requiring an explicit action tag.
+
+Example system prompt:
+
+.. code-block:: text
+
+ You are playing as Alice in an Iterated Prisoner's Dilemma game against Bob.
+ In each round, you must choose to either Cooperate (C) or Defect (D).
+
+ The payoffs are:
+ - If both players Cooperate: You each get 3 points
+ - If both players Defect: You each get 1 point
+ - If you Cooperate and Bob Defects: You get 0 points, Bob gets 5 points
+ - If you Defect and Bob Cooperates: You get 5 points, Bob gets 0 points
+
+ Your goal is to maximize your total points across all rounds.
+ The game will last for exactly 10 rounds, and both players know this.
+
+Example game state prompt:
+
+.. code-block:: text
+
+ Current round: 3/10
+
+ History:
+ Round 1: You chose C, Bob chose C. You earned 3 points.
+ Round 2: You chose C, Bob chose D. You earned 0 points.
+
+ Your total score so far: 3 points
+
+ What is your choice for round 3?
+ Please respond with C to cooperate or D to defect,
+ and explain your reasoning.
+
+Running IPD Games
+----------------------
+
+To run Iterated Prisoner's Dilemma games with LLM agents, you can use the following code structure:
+
+.. code-block:: python
+
+ from mllm.environments.ipd.ipd_game import IPDEnv
+ from mllm.environments.ipd.ipd_agent import IPDAgent
+ from mllm.run_matches import run_batched_matches
+
+ # Create environment
+ env = IPDEnv(
+ rounds_per_game=10,
+ reward=3.0,
+ punishment=1.0,
+ temptation=5.0,
+ sucker=0.0
+ )
+
+ # Create agent handlers
+ agent_handlers = {
+ "alice": IPDAgent(agent_id="alice"),
+ "bob": IPDAgent(agent_id="bob")
+ }
+
+ # Define policy mapping
+ policy_mapping = {
+ "llm_policy": my_llm_policy_function
+ }
+
+ # Run the game
+ game_results = run_batched_matches(
+ envs=[env],
+ agent_handlers_per_env=[agent_handlers],
+ policy_mapping=policy_mapping,
+ max_parallel_matches=1
+ )
+
+ # Process results
+ for result in game_results:
+ print(f"Game finished. Scores: {result['total_rewards']}")
+
+Statistics and Analysis
+-----------------------
+
+The IPD environment includes utility functions for analyzing game outcomes:
+
+1. **Cooperation Rates**: Percentage of rounds where each agent cooperated.
+2. **Mutual Cooperation/Defection**: Percentage of rounds where both agents made the same choice.
+3. **Score Distribution**: Analysis of how points were accumulated over the game.
+
+These statistics can be calculated using the ``gather_ipd_statistics`` function:
+
+.. code-block:: python
+
+ from mllm.environments.ipd.ipd_statistics_funcs import gather_ipd_statistics
+
+ stats = gather_ipd_statistics(match_info, env_info)
+ print(f"Cooperation rates: {stats['cooperation_rate']}")
+ print(f"Mutual cooperation rate: {stats['mutual_cooperation_rate']}")
+ print(f"Mutual defection rate: {stats['mutual_defection_rate']}")
+
+Limitations and Considerations
+------------------------------
+
+1. **Determinism**: The environment is deterministic, with randomness only in initialization if a seed is provided.
+
+2. **Limited Player Count**: The IPD environment only supports exactly two players.
+
+3. **Perfect Information**: Both players have perfect information about the game history.
+
+4. **Simultaneous Actions**: Both players act simultaneously, which requires adaptations for some LLM interfaces.
+
+5. **Fixed Game Length**: The total number of rounds is fixed and known to both players from the start.
+
+Advanced Usage
+--------------
+
+For advanced usage, you can customize:
+
+1. **Payoff Matrix**: Modify reward values to create different incentive structures.
+
+2. **System Prompts**: Customize the LLM's understanding of the game and potential strategies.
+
+3. **Error Handling**: Adjust how the agent responds to invalid LLM outputs.
+
+4. **Analysis**: Create custom statistics gathering for specific research questions.
+
+5. **Integration**: Connect the IPD environment to other negotiation frameworks or tournament systems.
\ No newline at end of file
diff --git a/src_code_for_reproducibility/docs/source/index.rst b/src_code_for_reproducibility/docs/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..cdc1b79539342a9c95ca0cdd9219bce74a7b2c8a
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/index.rst
@@ -0,0 +1,22 @@
+Welcome to LLM Negotiation's documentation!
+===========================================
+This library is a collection of tools for training and evaluating LLM-based agents in multi-agent environments. It is designed to be easy to use and extend.
+
+.. toctree::
+ :maxdepth: 3
+ :caption: Contents:
+
+ installation
+ marl_standard
+ environments
+ launch
+ usage
+ modules
+ contributing
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
\ No newline at end of file
diff --git a/src_code_for_reproducibility/docs/source/installation.rst b/src_code_for_reproducibility/docs/source/installation.rst
new file mode 100644
index 0000000000000000000000000000000000000000..b148f25d92fd8308e9695f7c17c2b91fb0c9a2c6
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/installation.rst
@@ -0,0 +1,10 @@
+Installation
+============
+
+To install the package, run:
+
+.. code-block:: bash
+
+ git clone https://github.com/yourusername/llm_negotiation.git
+ cd llm_negotiation
+ pip install -e .
\ No newline at end of file
diff --git a/src_code_for_reproducibility/docs/source/launch.rst b/src_code_for_reproducibility/docs/source/launch.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src_code_for_reproducibility/docs/source/marl_standard.rst b/src_code_for_reproducibility/docs/source/marl_standard.rst
new file mode 100644
index 0000000000000000000000000000000000000000..b5ea5529892c611b34255645ec68537a236754cf
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/marl_standard.rst
@@ -0,0 +1,141 @@
+==========================================================
+Abstract Standard for Multi-Agent Negotiation Environments
+==========================================================
+
+Multi-Agent Negotiation Environments require more features than gymnasium environments in order to be used as interfaces in general game running code.
+The fundamental differences between gymnasium environments and Multi-Agent Negotiation Environments are:
+
+1. Response from the LLM is a text action, not a discrete action. Therefore, appropriate parsing of the text is required. The model may need to be run multiple times to get the full action.
+ This is why we introduce the `AgentHandler` class, which is responsible for parsing the LLM's response.
+2. The environment needs to be able to handle multi-agent interactions.
+ This is why we introduce the `NegotiationEnvironment` class, which is responsible for handling the multi-agent interactions.
+3. MARL environments are complex to describe. In different contexts, the same environment may be described differently. Therefore, both the environment and the agent handlers are
+ responsible for describing a particular trajectory. This information is given by the `get_log_info` method.
+4. There might be a lot of overlap between the neural networks used by each agent. For instance, the same model may be used for all agents. This motivates a requirement for a
+ policy identifier for each agent.
+
+Taking inspiration from the `gymnasium <https://gymnasium.farama.org/>`_ library, we introduce a new standard for Multi-Agent Negotiation Environments.
+
+Our standard is based on the following features:
+
+Environments are of the form:
+
+.. code-block:: python
+
+ class MarlEnvironment():
+
+ def __init__(self):
+ """Initialize the environment."""
+ pass
+
+ def reset(self):
+ """Reset the environment to an initial state and return the initial observation.
+ Returns:
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
+ """
+ # (...)
+ return observation
+
+ def step(self, actions):
+ """Take a step in the environment using the provided actions.
+
+ Args:
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
+
+ Returns:
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
+ reward (dict): A dictionary where keys are agent identifiers and values are rewards.
+ done (bool): Whether the episode has ended.
+ info (dict): Additional information about the environment.
+ """
+ # (...)
+        return observations, reward, done, info
+
+ def get_log_info(self):
+ """Get additional information about the environment. This information is used to log the game.
+ Returns:
+ log_info (dict): Information about the environment required to log the game.
+ """
+ # (...)
+ return log_info
+
+ def render(self):
+ """Render the current state of the environment."""
+ pass
+
+ def close(self):
+ """Perform any necessary cleanup."""
+ pass
+
+
+ class AgentState():
+
+ def __init__(self):
+ """Initialize the agent state."""
+ pass
+
+ def step(self, observation_from_env, policy_output=None):
+ """Update the agent state based on the observation and action.
+ The action is the output of the LLM.
+ """
+
+ Args:
+ observation_from_env (dict): The observation of the environment.
+ policy_output : The output of the policy.
+
+ Returns:
+ policy_id (str): The policy identifier.
+ policy_input (dict): The input to the policy.
+ action : The official action to be sent to the environment.
+ done (bool): Whether the LLM action is ready to be sent to the environment.
+ info (dict): Additional information about the agent.
+ """
+ # (...)
+ return policy_id, policy_input, action, done, info
+
+ def get_log_info(self):
+ """Get information about the agent required to log a trajectory.
+ Returns:
+ log_info (dict): Information about the agent required to log a trajectory.
+ """
+ # (...)
+ return log_info
+
+ def render(self):
+ """Render the current state of the environment."""
+ pass
+
+ def close(self):
+ """Perform any necessary cleanup."""
+ pass
+
+
+Implicitly, the keys of the `observations` in the `step` method of the `MarlEnvironment` interface represent the set of agents from which an action is expected at the current step. The next step should only expect actions from the agents in the `observations` dictionary.
+
+As you can see, both classes have a `get_log_info` method. This method is used to log the game. It returns a dictionary with keys being the agent identifiers and values being the information to log. The reason we need this is because the environment and the agent handler may need to log different information. It makes it easier to log from the perspective of each agent. The core environment class should not need to know about the details of the agent handler.
+
+
+
+Running Environments in Parallel
+--------------------------------
+This standard allows the use of the `run_batched_matches` function (TODO: link) to run environments in an efficient way. The core idea is to batch the policy calls for all agents in the environment.
+
+.. note::
+ The ``run_batched_matches`` function allows you to run multiple negotiation games, or "matches," in parallel.
+ After each environment is initialized, the function continuously loops over all active matches and checks which agents
+ are still pending actions. Each agent's logic can require multiple calls to the policy (e.g., an LLM) before an action
+ becomes "ready" to be sent to the environment. (For instance, an agent might need multiple policy calls before having a string which can be parsed into a valid action.) While an agent is waiting for a policy output, these calls for all agents across all matches are grouped together by unique policy identifier and processed in batch for efficiency. This is the core functionality of the ``run_batched_matches`` function.
+
+ Only once all actions from the required agents at a given step for an environment are ready does the function make a single ``env.step(...)`` call; this ensures
+ every match moves forward in lockstep for all its active agents. As soon as an environment signals it is done, the function
+ retrieves logged information from both the environment and the agent states before removing this match from the active set.
+
+ If there are more matches waiting to be processed, they are then started one by one to maintain the specified degree of parallelism.
+ This batching approach provides an efficient mechanism to handle multi-agent or multi-policy environments, ensuring minimal
+ overhead and a clear, unified flow for stepping through matches.
+
+Here is a diagram that shows how the `run_batched_matches` function works at a high level:
+
+.. image:: media/runbatch.png
+ :alt: Alternate text for the image
+ :width: 1000px
diff --git a/src_code_for_reproducibility/docs/source/media/runbatch.png b/src_code_for_reproducibility/docs/source/media/runbatch.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7572fa514d9e029a6c08e7061fa88b03bc63de2
Binary files /dev/null and b/src_code_for_reproducibility/docs/source/media/runbatch.png differ
diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst
new file mode 100644
index 0000000000000000000000000000000000000000..8fab765a9c7e749bd446533fdddb5fa5b55e6635
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst
@@ -0,0 +1,7 @@
+src.environments.dond.dond\_agent module
+========================================
+
+.. automodule:: src.environments.dond.dond_agent
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst
new file mode 100644
index 0000000000000000000000000000000000000000..d0e595aad169a5a8456f83afe5029e7475d7c9e7
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst
@@ -0,0 +1,7 @@
+src.environments.dond.dond\_game module
+=======================================
+
+.. automodule:: src.environments.dond.dond_game
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst
new file mode 100644
index 0000000000000000000000000000000000000000..cf96327d1bcbc7f0f8785804a49a6975eef889c2
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst
@@ -0,0 +1,7 @@
+src.environments.dond.dond\_log\_funcs module
+=============================================
+
+.. automodule:: src.environments.dond.dond_log_funcs
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst
new file mode 100644
index 0000000000000000000000000000000000000000..bab97f1009eb2d5c4e387ac6a83982a51e33c9e3
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst
@@ -0,0 +1,7 @@
+src.environments.dond.dond\_agent module
+=========================================
+
+.. automodule:: src.environments.dond.dond_agent
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e8084f8f9a291efe75e032183ebd5eeab58f5e41
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst
@@ -0,0 +1,7 @@
+src.environments.dond.dond\_return\_funcs module
+================================================
+
+.. automodule:: src.environments.dond.dond_return_funcs
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst
new file mode 100644
index 0000000000000000000000000000000000000000..cf31d696a3ed580e24f3c5dffd6f7a2851d16320
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst
@@ -0,0 +1,7 @@
+src.environments.dond.dond\_training\_data\_funcs module
+========================================================
+
+.. automodule:: src.environments.dond.dond_training_data_funcs
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.env_imports.rst b/src_code_for_reproducibility/docs/source/src.environments.env_imports.rst
new file mode 100644
index 0000000000000000000000000000000000000000..4354ba27eee9f0e0fa3f4f0e5d9131c256a4be57
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.env_imports.rst
@@ -0,0 +1,7 @@
+src.environments.env\_imports module
+====================================
+
+.. automodule:: src.environments.env_imports
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst b/src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst
new file mode 100644
index 0000000000000000000000000000000000000000..d22c53e31cd1c7c064955900c19f34ac51c7006f
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst
@@ -0,0 +1,7 @@
+src.environments.environment\_imports module
+============================================
+
+.. automodule:: src.environments.environment_imports
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_agent.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_agent.rst
new file mode 100644
index 0000000000000000000000000000000000000000..4845b371089c529493f70de77ceaee0b7500571b
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_agent.rst
@@ -0,0 +1,7 @@
+src.environments.ipd.ipd\_agent module
+======================================
+
+.. automodule:: src.environments.ipd.ipd_agent
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst
new file mode 100644
index 0000000000000000000000000000000000000000..ede471ef9675c780410189fcf63df0c1a05496d0
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst
@@ -0,0 +1,7 @@
+src.environments.ipd.ipd\_game module
+=====================================
+
+.. automodule:: src.environments.ipd.ipd_game
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst
new file mode 100644
index 0000000000000000000000000000000000000000..5f54afac07c4d477067ef4c2bf5d883b236cf5fc
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst
@@ -0,0 +1,7 @@
+src.environments.ipd.ipd\_statistics\_funcs module
+==================================================
+
+.. automodule:: src.environments.ipd.ipd_statistics_funcs
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.environments.rst b/src_code_for_reproducibility/docs/source/src.environments.rst
new file mode 100644
index 0000000000000000000000000000000000000000..221ed1c07ebea145cd23bc06c6474d34b1d8a33e
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.environments.rst
@@ -0,0 +1,25 @@
+src.environments package
+========================
+
+.. automodule:: src.environments
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ src.environments.dond
+ src.environments.ipd
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ src.environments.env_imports
+ src.environments.environment_imports
diff --git a/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst b/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst
new file mode 100644
index 0000000000000000000000000000000000000000..68e0f5da020aee80cc8895ca650a6067317f4bcd
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst
@@ -0,0 +1,7 @@
+src.experiments.arithmetic\_test module
+=======================================
+
+.. automodule:: src.experiments.arithmetic_test
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst b/src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst
new file mode 100644
index 0000000000000000000000000000000000000000..1b868ee566283d662a51387046bc070a131f5222
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst
@@ -0,0 +1,7 @@
+src.experiments.last\_completion module
+=======================================
+
+.. automodule:: src.experiments.last_completion
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.generation.rst b/src_code_for_reproducibility/docs/source/src.generation.rst
new file mode 100644
index 0000000000000000000000000000000000000000..14bb2b1364da7067aed5c37e3c77d091d20f011b
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.generation.rst
@@ -0,0 +1,15 @@
+src.generation package
+======================
+
+.. automodule:: src.generation
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ src.generation.run_games
diff --git a/src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst b/src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst
new file mode 100644
index 0000000000000000000000000000000000000000..13b40bd388e445fa60a3c3fc2e089ad89c452dbd
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst
@@ -0,0 +1,7 @@
+src.models.dummy\_local\_llm module
+===================================
+
+.. automodule:: src.models.dummy_local_llm
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.models.local_llm.rst b/src_code_for_reproducibility/docs/source/src.models.local_llm.rst
new file mode 100644
index 0000000000000000000000000000000000000000..5c2eebb05e64919d1915eeb63dc18f5e9a36eb2c
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.models.local_llm.rst
@@ -0,0 +1,7 @@
+src.models.local\_llm module
+============================
+
+.. automodule:: src.models.local_llm
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst b/src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst
new file mode 100644
index 0000000000000000000000000000000000000000..d65981b7a6bbf99327fb467a63644bbe98d137f4
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst
@@ -0,0 +1,7 @@
+src.models.new\_local\_llm module
+=================================
+
+.. automodule:: src.models.new_local_llm
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.models.rst b/src_code_for_reproducibility/docs/source/src.models.rst
new file mode 100644
index 0000000000000000000000000000000000000000..d03983340a5b0317354d1895df709277d5a4baed
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.models.rst
@@ -0,0 +1,20 @@
+src.models package
+==================
+
+.. automodule:: src.models
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ src.models.dummy_local_llm
+ src.models.local_llm
+ src.models.new_local_llm
+ src.models.server_llm
+ src.models.updatable_worker
+ src.models.vllm_worker_wrap
diff --git a/src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst b/src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst
new file mode 100644
index 0000000000000000000000000000000000000000..ee05dfbe7dd407eed4e275f525c04ca6f2c857ae
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst
@@ -0,0 +1,7 @@
+src.models.updatable\_worker module
+===================================
+
+.. automodule:: src.models.updatable_worker
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst b/src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst
new file mode 100644
index 0000000000000000000000000000000000000000..5daf4b7250022f523242d6239d0921f362df6d24
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst
@@ -0,0 +1,7 @@
+src.training.reinforce\_training module
+=======================================
+
+.. automodule:: src.training.reinforce_training
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst b/src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst
new file mode 100644
index 0000000000000000000000000000000000000000..cf5db1aa0cb6d010fc70f86c341467ba5e9b485e
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst
@@ -0,0 +1,7 @@
+src.training.rl\_convs\_processing module
+=========================================
+
+.. automodule:: src.training.rl_convs_processing
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.training.rst b/src_code_for_reproducibility/docs/source/src.training.rst
new file mode 100644
index 0000000000000000000000000000000000000000..50539fcda2bffa46a72eb48874a7532bf296ff27
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.training.rst
@@ -0,0 +1,19 @@
+src.training package
+====================
+
+.. automodule:: src.training
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ src.training.ppo_train
+ src.training.ppo_train_value_head
+ src.training.reinforce_training
+ src.training.rl_convs_processing
+ src.training.train_main
diff --git a/src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst b/src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst
new file mode 100644
index 0000000000000000000000000000000000000000..44b83082b6eb027ef402603e034c712ccc2cbfcc
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst
@@ -0,0 +1,7 @@
+src.utils.log\_gpu\_usage module
+================================
+
+.. automodule:: src.utils.log_gpu_usage
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/src.utils.rst b/src_code_for_reproducibility/docs/source/src.utils.rst
new file mode 100644
index 0000000000000000000000000000000000000000..4f5cb352cc9ec645c968d0ae99798d47c018c750
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.utils.rst
@@ -0,0 +1,24 @@
+src.utils package
+=================
+
+.. automodule:: src.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ src.utils.common_imports
+ src.utils.export_ppo_training_set
+ src.utils.extra_stats
+ src.utils.inherit_args
+ src.utils.log_gpu_usage
+ src.utils.log_statistics
+ src.utils.model_to_cpu
+ src.utils.parallel_shuffle
+ src.utils.quick_stats
+ src.utils.update_start_epoch
diff --git a/src_code_for_reproducibility/docs/source/src.utils.update_start_epoch.rst b/src_code_for_reproducibility/docs/source/src.utils.update_start_epoch.rst
new file mode 100644
index 0000000000000000000000000000000000000000..72cbad9bd09e056213a2e4cd00a6ba624be333cb
--- /dev/null
+++ b/src_code_for_reproducibility/docs/source/src.utils.update_start_epoch.rst
@@ -0,0 +1,7 @@
+src.utils.update\_start\_epoch module
+=====================================
+
+.. automodule:: src.utils.update_start_epoch
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/src_code_for_reproducibility/docs/source/usage.rst b/src_code_for_reproducibility/docs/source/usage.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc b/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87d439427c824e863dc4f7216198d6170539d30e
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a3a6c09c6f7608acfeddebc65fdb4149e8c3cb7
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efe246c29687cefb7069674e57114ea38dab252a
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5788779130aac0fc02a9c1b2ac985b4cd5427eb1
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebad9ae7dd0fb04490740075b6d765f0d0237b2c
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68ef22b6f0c70af174960ab268da175e89c9d7e3
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d4da97bca809a1e8526f5dc00c3ee23d8b5c021
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9be34b4a5f90d0e3c622100495120366881e0b0
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-311.pyc b/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56bfac1adc30d8ee8d1e0defc19d26ead0758f4b
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-311.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bede63b1b249138a212198e9d178e47b63db35d2
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caebf44fc211bb439d6ffe745a8fe8b08d71f568
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b3d76acd3764faebf2875304319edea4aa5e856
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ff402e23224fc7961d9e6796c40daaf2ab4bbaa
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py
@@ -0,0 +1,259 @@
+from typing import Dict, List, Tuple, Optional, Any
+import copy
+
class DiplomacyAgent:
    """Agent handler for Diplomacy game that follows the MARL standard.

    This class is responsible for parsing LLM output into valid Diplomacy orders,
    managing the agent state, and providing information for logging.
    """

    def __init__(self, policy_id: str, power_name: str, random_valid_move=False):
        """Initialize the agent handler for a power in the Diplomacy game.

        Args:
            policy_id: The identifier for the policy this agent uses
            power_name: The name of the power this agent controls (e.g., 'FRANCE', 'ENGLAND')
            random_valid_move: If True, will select random valid moves instead of using LLM (default: False)
        """
        self.policy_id = policy_id
        self.power_name = power_name
        self.orders = []  # orders most recently committed to the environment
        self.wait = True  # whether this power is still waiting before the phase is processed
        self.processing_state = "WAITING_FOR_ORDERS"
        self.parsed_orders = []  # orders extracted from the most recent LLM output
        self.order_status = {}
        self.message_history = []
        self.random_valid_move = random_valid_move

    def step(self, observation_from_env, policy_output=None):
        """Update the agent state based on the observation and LLM output.

        Args:
            observation_from_env: The observation from the environment
            policy_output: The output from the LLM (None on the first call of a phase)

        Returns:
            policy_id: The policy identifier
            policy_input: The input to the policy ({"prompt": ...} dict, or {} when done)
            action: The official action to be sent to the environment (None until ready)
            done: Whether the LLM action is ready to be sent to the environment
            info: Additional information about the agent
        """
        info = {}

        # If random_valid_move is enabled, select random valid moves
        if self.random_valid_move:
            valid_orders = self._select_random_valid_moves(observation_from_env)
            self.orders = valid_orders
            self.wait = False
            action = {
                "orders": valid_orders,
                "wait": False
            }
            return self.policy_id, {}, action, True, info

        # If no policy output, this is the initial step - prepare prompt
        if policy_output is None:
            # Create initial prompt for the LLM
            phase = observation_from_env.get('phase', '')
            units = observation_from_env.get('units', {}).get(self.power_name, [])
            centers = observation_from_env.get('centers', {}).get(self.power_name, [])
            orderable_locations = observation_from_env.get('orderable_locations', {})

            prompt = self._create_prompt(phase, units, centers, orderable_locations)

            return self.policy_id, {"prompt": prompt}, None, False, info

        # Process the LLM output to extract orders
        success, parsed_orders = self._parse_llm_output(policy_output)
        self.parsed_orders = parsed_orders

        if not success:
            # Need more information from LLM
            clarification_prompt = self._create_clarification_prompt(policy_output, parsed_orders)
            return self.policy_id, {"prompt": clarification_prompt}, None, False, info

        # Validate if the orders are valid for the current phase
        valid_orders = self._validate_orders(parsed_orders, observation_from_env)

        if valid_orders:
            # Orders are valid, prepare action for environment
            self.orders = valid_orders
            self.wait = False
            action = {
                "orders": valid_orders,
                "wait": False
            }
            return self.policy_id, {}, action, True, info
        else:
            # Orders are invalid, ask for new ones
            error_prompt = self._create_error_prompt(parsed_orders, observation_from_env)
            return self.policy_id, {"prompt": error_prompt}, None, False, info

    def _create_prompt(self, phase, units, centers, orderable_locations):
        """Create the initial prompt for the LLM.

        Args:
            phase: The current game phase (e.g. 'S1901M'; the trailing letter selects
                the instruction block: M=movement, R=retreat, A=adjustment)
            units: List of units controlled by this power
            centers: List of supply centers controlled by this power
            orderable_locations: Locations where orders can be issued
                (iterable of location strings; a dict of locations also works
                since only its keys are joined)

        Returns:
            A prompt string for the LLM
        """
        prompt = f"You are playing as {self.power_name} in Diplomacy. The current phase is {phase}.\n\n"
        prompt += f"Your units: {', '.join(units)}\n"
        prompt += f"Your supply centers: {', '.join(centers)}\n"
        prompt += f"Locations you can order: {', '.join(orderable_locations)}\n\n"

        if phase.endswith('M'):  # Movement phase
            prompt += "Please provide orders for your units in the form:\n"
            prompt += "- A LON H (hold)\n"
            prompt += "- F NTH - NWY (move)\n"
            prompt += "- A WAL S F LON (support)\n"
            prompt += "- F NWG C A NWY - EDI (convoy)\n"
        elif phase.endswith('R'):  # Retreat phase
            prompt += "Please provide retreat orders for your dislodged units:\n"
            prompt += "- A PAR R MAR (retreat to MAR)\n"
            prompt += "- A PAR D (disband)\n"
        elif phase.endswith('A'):  # Adjustment phase
            if len(units) < len(centers):
                prompt += "You can build units. Please provide build orders:\n"
                prompt += "- A PAR B (build army in PAR)\n"
                prompt += "- F BRE B (build fleet in BRE)\n"
                prompt += "- WAIVE (waive a build)\n"
            elif len(units) > len(centers):
                prompt += "You must remove units. Please provide disbandment orders:\n"
                prompt += "- A PAR D (disband army in PAR)\n"
                prompt += "- F BRE D (disband fleet in BRE)\n"

        prompt += "\nProvide your orders as a list, one per line."
        return prompt

    def _parse_llm_output(self, llm_output):
        """Parse the LLM output to extract orders.

        Uses a keyword heuristic rather than a grammar: a line is kept if it
        contains a token pattern typical of a Diplomacy order (H/-/S/C/R/D/B
        or the literal 'WAIVE'). False positives are possible and are handled
        downstream by validation.

        Args:
            llm_output: The raw output from the LLM

        Returns:
            success: Whether parsing was successful (at least one order found)
            parsed_orders: List of parsed orders
        """
        # Simple parsing for now - extract lines that look like orders
        lines = llm_output.strip().split('\n')
        orders = []

        for line in lines:
            # Remove list markers, hyphens, etc.
            # NOTE: str.strip with a char set removes any of '-', ' ', '*', '•'
            # from both ends, not the literal prefix string.
            line = line.strip('- *•').strip()

            # Skip empty lines and lines that don't look like orders
            if not line or line.startswith('I ') or line.startswith('Let\'s'):
                continue

            # Check if it looks like a Diplomacy order
            if (' H' in line or ' -' in line or ' S ' in line or ' C ' in line or
                    ' R ' in line or ' D' in line or ' B' in line or line == 'WAIVE'):
                orders.append(line)

        return len(orders) > 0, orders

    def _validate_orders(self, orders, observation):
        """Validate if the orders are valid for the current phase.

        Args:
            orders: List of orders to validate
            observation: Current observation from the environment

        Returns:
            List of valid orders or None if invalid
        """
        # For simplicity, we'll assume all parsed orders are valid
        # In a real implementation, we would use the game's validation logic
        return orders

    def _create_clarification_prompt(self, previous_output, parsed_orders):
        """Create a prompt asking for clarification when orders couldn't be parsed.

        Args:
            previous_output: The previous LLM output (currently unused; kept for context)
            parsed_orders: Any orders that were successfully parsed

        Returns:
            A prompt string for the LLM
        """
        prompt = f"I couldn't fully understand your orders for {self.power_name}. "

        if parsed_orders:
            prompt += f"I understood these orders:\n"
            for order in parsed_orders:
                prompt += f"- {order}\n"

        prompt += "\nPlease provide clear, valid Diplomacy orders in the format:\n"
        prompt += "- A LON H\n- F NTH - NWY\n- etc.\n"
        return prompt

    def _create_error_prompt(self, invalid_orders, observation):
        """Create a prompt when orders are invalid.

        Args:
            invalid_orders: The invalid orders
            observation: Current observation from the environment (unused here)

        Returns:
            A prompt string for the LLM
        """
        prompt = f"The following orders for {self.power_name} are invalid:\n"
        for order in invalid_orders:
            prompt += f"- {order}\n"

        prompt += "\nPlease provide valid orders for your units."
        return prompt

    def get_log_info(self):
        """Get information about the agent required to log a trajectory.

        Returns:
            log_info: Information about the agent required to log a trajectory.
                Note the key is "parsing_state" although the attribute is
                named processing_state.
        """
        return {
            "power_name": self.power_name,
            "orders": self.orders,
            "wait": self.wait,
            "parsing_state": self.processing_state,
            "message_history": self.message_history
        }

    def render(self):
        """Render the current state of the agent (prints to stdout)."""
        print(f"Power: {self.power_name}")
        print(f"Orders: {self.orders}")
        print(f"Wait: {self.wait}")

    def close(self):
        """Perform any necessary cleanup."""
        pass

    def _select_random_valid_moves(self, observation):
        """Select random valid moves for all units.

        Args:
            observation: Current observation from the environment; expected to
                carry a 'possible_orders' mapping of location -> list of orders

        Returns:
            List of valid orders (one randomly chosen per orderable location)
        """
        import random

        possible_orders = observation.get('possible_orders', {})
        valid_orders = []

        # For each location with possible orders, select one randomly
        for location, orders in possible_orders.items():
            if orders:  # If there are any possible orders for this location
                valid_orders.append(random.choice(orders))

        return valid_orders
\ No newline at end of file
diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b72612c43f2535d353b0157ce72a9b79c23cbb3
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py
@@ -0,0 +1,230 @@
+from typing import Dict, List, Tuple, Optional, Any
+from diplomacy import Game
+import random
+
class DiplomacyEnv:
    """Multi-Agent Reinforcement Learning environment for Diplomacy.

    This class wraps the Diplomacy game engine to provide an interface
    compliant with the MARL standard.
    """

    def __init__(self, random_seed=None, map_name="standard", game_id=None, rules=None, max_steps=50):
        """Initialize the Diplomacy environment.

        Args:
            random_seed: Optional seed stored for bookkeeping (not forwarded
                to the game engine here).
            map_name: The name of the map to use (default: "standard")
            game_id: Optional game ID
            rules: Optional list of rules to apply to the game
            max_steps: Maximum number of steps before forcing game end (default: 50)
        """
        self.random_seed = random_seed
        self.map_name = map_name
        self.game_id = game_id
        self.rules = rules or []
        self.game = None  # created lazily in reset()
        self.active_powers = []
        self.render_mode = None
        self.max_steps = max_steps
        self.current_steps = 0

    def reset(self):
        """Reset the environment to an initial state and return the initial observation.

        Returns:
            observation: A dictionary where keys are agent identifiers (power
                names) and values are observations.
        """
        # Initialize a new game
        self.game = Game(game_id=self.game_id, map_name=self.map_name)

        # Apply rules
        for rule in self.rules:
            self.game.add_rule(rule)

        # Determine active powers (not eliminated)
        self.active_powers = [name for name, power in self.game.powers.items()
                              if not power.is_eliminated()]

        # Reset step counter
        self.current_steps = 0

        # Create initial observations for all powers
        observations = {}
        for power_name in self.active_powers:
            observations[power_name] = self._create_observation(power_name)

        return observations

    def step(self, actions):
        """Take a step in the environment using the provided actions.

        Args:
            actions: A dictionary where keys are agent identifiers and values
                are actions of the form {"orders": [...], "wait": bool}.

        Returns:
            observations: A dictionary where keys are agent identifiers and values are observations.
            done: Whether the episode has ended (game over or max_steps reached).
            info: Additional information about the environment.
        """
        self.current_steps += 1

        # Apply actions (orders) for each power; actions for eliminated
        # powers are silently ignored.
        for power_name, action in actions.items():
            if power_name in self.active_powers:
                orders = action.get("orders", [])
                wait = action.get("wait", True)

                # Set orders for the power
                if orders:
                    self.game.set_orders(power_name, orders)

                # Set wait flag
                self.game.set_wait(power_name, wait)

        # Process the phase only once every power has cleared its wait flag.
        if self.game.does_not_wait():
            self.game.process()

        # Update active powers list after processing
        self.active_powers = [name for name, power in self.game.powers.items()
                              if not power.is_eliminated()]

        # Create observations for all active powers
        observations = {}
        for power_name in self.active_powers:
            observations[power_name] = self._create_observation(power_name)

        # Check if the game is done (either naturally or due to max steps)
        done = self.game.is_game_done or self.current_steps >= self.max_steps

        # Create info dict
        info = {
            "phase": self.game.get_current_phase(),
            "active_powers": self.active_powers,
            "centers": self.game.get_centers(),
            "units": self.game.get_units(),
            "current_steps": self.current_steps,
            "max_steps_reached": self.current_steps >= self.max_steps
        }

        return observations, done, info

    def _create_observation(self, power_name):
        """Create observation for a specific power.

        Args:
            power_name: The name of the power

        Returns:
            An observation dictionary with phase, units, centers, orderable
            locations, order status and possible orders.
        """
        observation = {
            "phase": self.game.get_current_phase(),
            "units": self.game.get_units(),
            "centers": self.game.get_centers(),
            "orderable_locations": self.game.get_orderable_locations(power_name),
            "order_status": self.game.get_order_status(power_name),
            "possible_orders": self._get_possible_orders_for_power(power_name)
        }
        return observation

    def _get_possible_orders_for_power(self, power_name):
        """Get all possible orders for a power's units.

        Args:
            power_name: The name of the power

        Returns:
            A dictionary mapping locations to their possible orders
        """
        all_possible_orders = self.game.get_all_possible_orders()

        # Filter for only the locations where this power has units.
        # unit strings look like "A PAR" / "F BRE": unit[2:] drops the
        # type prefix, leaving the location.
        power_units = self.game.get_units(power_name)
        power_unit_locations = [unit[2:] for unit in power_units]

        # For retreat phases, include retreating units
        if self.game.phase_type == 'R':
            power = self.game.get_power(power_name)
            power_unit_locations.extend([unit[2:] for unit in power.retreats])

        # For adjustment phases, include buildable locations
        elif self.game.phase_type == 'A':
            power = self.game.get_power(power_name)
            # If we have more centers than units, we can build
            if len(power.centers) > len(power.units):
                # NOTE(review): _build_sites is a private engine API; verify
                # it remains available across diplomacy package versions.
                buildable_sites = self.game._build_sites(power)
                power_unit_locations.extend(buildable_sites)
            # If we have more units than centers, we need to remove
            elif len(power.units) > len(power.centers):
                # All units are candidates for removal
                pass

        # Filter the possible orders to only those for this power's units/locations.
        # loc[:3] compares the 3-letter province code so coastal variants
        # (e.g. 'BUL/EC') still match their province.
        power_possible_orders = {}
        for loc, orders in all_possible_orders.items():
            if loc[:3] in power_unit_locations:
                power_possible_orders[loc] = orders

        return power_possible_orders

    def get_log_info(self):
        """Get additional information about the environment for logging.

        Returns:
            log_info: Information about the environment required to log the
                game; empty dict if reset() has not been called yet.
        """
        if not self.game:
            return {}

        return {
            "game_id": self.game.game_id,
            "phase": self.game.get_current_phase(),
            "map_name": self.game.map_name,
            "centers": self.game.get_centers(),
            "units": self.game.get_units(),
            "powers": {name: {
                "units": power.units,
                "centers": power.centers,
                "is_eliminated": power.is_eliminated(),
                "order_status": self.game.get_order_status(name)
            } for name, power in self.game.powers.items()},
            "orders": self.game.get_orders(),
            "active_powers": self.active_powers,
            "is_game_done": self.game.is_game_done,
            "outcome": self.game.outcome if self.game.is_game_done else None
        }

    def render(self, mode='human'):
        """Render the current state of the environment.

        Args:
            mode: The rendering mode ('human' prints to stdout, 'svg' returns markup)

        Returns:
            The rendered SVG string for mode='svg', otherwise None
        """
        self.render_mode = mode
        if self.game:
            if mode == 'human':
                # Just print basic game state
                print(f"Game: {self.game.game_id}")
                print(f"Phase: {self.game.get_current_phase()}")
                print(f"Active Powers: {self.active_powers}")
                print("Supply Centers:")
                for power_name, centers in self.game.get_centers().items():
                    print(f"  {power_name}: {centers}")
                print("Units:")
                for power_name, units in self.game.get_units().items():
                    print(f"  {power_name}: {units}")
                return None
            elif mode == 'svg':
                # Return SVG representation
                return self.game.render(output_format='svg')
        return None

    def close(self):
        """Perform any necessary cleanup (drops the game reference)."""
        self.game = None
\ No newline at end of file
diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f60e82d9738c116c0b8b8d3f7818eddebb18fa2
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging.py
@@ -0,0 +1,360 @@
+import os
+import json
+from utils.common_imports import *
+
+
+
def diplomacy_log_match(
    path,
    agents_log_info,
    env_log_info,
    metrics_func=None,
    metrics_func_args=None
    ):
    """
    Logs the Diplomacy game data and generates HTML visualizations using the get_log_info methods.

    Args:
        path (str): Base path to save the data.
        agents_log_info (list): List of agent information dictionaries containing the get_log_info results.
        env_log_info (dict): Environment information from its get_log_info method.
        metrics_func (str, optional): Name of a module-level function used to calculate metrics.
            It is called as metrics_func(agent_log, env_log_info, **metrics_func_args).
        metrics_func_args (dict, optional): Extra keyword arguments for the metrics function.
    """
    # Create directory structure
    os.makedirs(path, exist_ok=True)

    # Save the environment log info
    env_log_path = os.path.join(path, "env_log.json")
    with open(env_log_path, "w") as f:
        json.dump(env_log_info, f, indent=4, default=_json_serialize)

    # Process each agent's log info
    for agent_log in agents_log_info:
        power_name = agent_log["power_name"]

        # Define paths for raw data and statistics subfolders
        power_path = os.path.join(path, power_name)
        raw_data_path = os.path.join(power_path, "raw_data")
        statistics_path = os.path.join(power_path, "statistics")

        # Ensure directories exist
        os.makedirs(raw_data_path, exist_ok=True)
        os.makedirs(statistics_path, exist_ok=True)

        # Determine the next available file number for raw data
        raw_numbers = [
            int(f.split('_')[-1].split('.')[0])
            for f in os.listdir(raw_data_path)
            if f.startswith("log_")
        ]
        next_raw_number = max(raw_numbers, default=0) + 1
        raw_file = os.path.join(raw_data_path, f"log_{next_raw_number}.json")

        # Save agent log info
        with open(raw_file, "w") as f:
            json.dump(agent_log, f, indent=4, default=_json_serialize)

        # Log metrics if a metrics function is provided
        if metrics_func:
            metrics_numbers = [
                int(f.split('_')[-1].split('.')[0])
                for f in os.listdir(statistics_path)
                if f.startswith("metrics_")
            ]
            next_metrics_number = max(metrics_numbers, default=0) + 1
            metrics_file = os.path.join(statistics_path, f"metrics_{next_metrics_number}.json")

            # BUG FIX: the original passed an undefined name `info` (NameError)
            # and unpacked metrics_func_args even when it was None (TypeError).
            metrics = globals()[metrics_func](
                agent_log, env_log_info, **(metrics_func_args or {})
            )
            with open(metrics_file, "w") as f:
                # Use the same fallback serializer as the other dumps so
                # engine objects in the metrics don't crash the logger.
                json.dump(metrics, f, indent=4, default=_json_serialize)

    # Generate the HTML visualization
    html_content = generate_diplomacy_html(agents_log_info, env_log_info)

    # Ensure the html directory exists
    html_path = os.path.join(path, "html")
    os.makedirs(html_path, exist_ok=True)

    # Determine the next available file number for HTML
    html_numbers = [
        int(f.split('_')[-1].split('.')[0])
        for f in os.listdir(html_path)
        if f.startswith("game_summary_")
    ]
    next_html_number = max(html_numbers, default=0) + 1
    html_file = os.path.join(html_path, f"game_summary_{next_html_number}.html")

    # Save the HTML content to a file
    with open(html_file, "w") as f:
        f.write(html_content)
+
def generate_diplomacy_html(agent_infos, env_info):
    """
    Generate HTML visualization for a Diplomacy game.

    NOTE(review): the original body of this function was corrupted in
    extraction (markup stripped, string literals split across lines,
    escape entities mangled); this is a reconstruction that emits
    well-formed HTML from the same data fields.

    Args:
        agent_infos (list): List of agent information dictionaries from get_log_info.
        env_info (dict): Environment information from get_log_info.

    Returns:
        str: HTML content for the game visualization.
    """

    def _esc(value):
        # Minimal HTML escaping; '&' must be replaced first.
        return (
            str(value)
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
        )

    # Extract game information
    game_id = env_info.get("game_id", "Unknown")
    phase = env_info.get("phase", "Unknown")
    map_name = env_info.get("map_name", "standard")
    is_game_done = env_info.get("is_game_done", False)
    outcome = env_info.get("outcome", []) or []
    centers = env_info.get("centers", {}) or {}
    units = env_info.get("units", {}) or {}

    parts = [
        "<!DOCTYPE html>",
        "<html><head><meta charset='utf-8'>",
        f"<title>Diplomacy Game {_esc(game_id)}</title>",
        "</head><body>",
        "<h1>Game Information</h1>",
        "<h2>Game Details</h2>",
        f"<p>Game ID: {_esc(game_id)}</p>",
        f"<p>Phase: {_esc(phase)}</p>",
        f"<p>Map: {_esc(map_name)}</p>",
        f"<p>Status: {'Completed' if is_game_done else 'Active'}</p>",
        "<h2>Supply Centers</h2>",
        "<ul>",
    ]

    # Add supply center counts per power
    for power, power_centers in centers.items():
        parts.append(f"<li>{_esc(power)}: {len(power_centers)}</li>")
    parts.append("</ul>")

    # Add outcome if game is done. By the original convention, the first
    # element of `outcome` is the final phase and the rest are winners.
    if is_game_done and outcome:
        winners = outcome[1:] if len(outcome) > 1 else ["Draw"]
        parts.append("<h2>Game Outcome</h2>")
        parts.append(f"<p>Winners: {_esc(', '.join(winners))}</p>")

    # Add each power's information
    for agent_log in agent_infos:
        power_name = agent_log["power_name"]
        orders = agent_log.get("orders", [])
        message_history = agent_log.get("message_history", [])

        parts.append(f"<h1>{_esc(power_name)}</h1>")

        # Units
        parts.append("<h2>Units</h2><ul>")
        for unit in units.get(power_name, []):
            parts.append(f"<li>{_esc(unit)}</li>")
        parts.append("</ul>")

        # Final orders
        parts.append("<h2>Final Orders</h2><ul>")
        for order in orders:
            parts.append(f"<li>{_esc(order)}</li>")
        parts.append("</ul>")

        # Message history
        parts.append("<h2>Messages</h2>")
        for message in message_history:
            if isinstance(message, dict):
                # Skip system prompts in the rendered transcript
                if message.get("role") == "system":
                    continue

                role = message.get("role", "unknown")
                content = _esc(message.get("content", "")).replace("\n", "<br>")
                role_display = "Environment" if role == "user" else f"LLM ({power_name})"
                parts.append(f"<h3>{_esc(role_display)}</h3><p>{content}</p>")
            elif isinstance(message, str):
                # Simple string messages (may be used in some implementations)
                parts.append(f"<p>{_esc(message)}</p>")

    parts.append("</body></html>")
    return "\n".join(parts)
+
+def _json_serialize(obj):
+ """
+ A helper function to convert non-JSON-serializable objects
+ (like OrderResult) into strings or dicts.
+ """
+ # Check for the specific object types you know are problematic
+ if obj.__class__.__name__ == "OrderResult":
+ # Return a string representation or a dict
+ return str(obj)
+
+ # Fallback: attempt to convert anything else to string
+ return str(obj)
\ No newline at end of file
diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging_for_training.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging_for_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src_code_for_reproducibility/markov_games/ipd/Ipd_hard_coded_agents.py b/src_code_for_reproducibility/markov_games/ipd/Ipd_hard_coded_agents.py
new file mode 100644
index 0000000000000000000000000000000000000000..a974bddc69c1a3002ce5d84aac868f59bb731900
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/ipd/Ipd_hard_coded_agents.py
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+from typing import Any, Tuple
+
+from mllm.markov_games.ipd.ipd_agent import IPDAgent
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+
+
@dataclass
class AlwaysCooperateIPDAgent(IPDAgent):
    async def act(self, observation) -> Tuple[Any, AgentActLog]:
        """
        Unconditionally cooperate, regardless of the observation.

        Emits the configured cooperate_string so the simulation parses the
        move as "C", records one assistant chat turn for log parity with the
        LLM-backed agents, and advances the state counters.
        """
        move = self.cooperate_string

        # One structured assistant turn so trajectories stay uniform
        turn = ChatTurn(
            agent_id=self.agent_id,
            role="assistant",
            content=f"Playing cooperate: {move}",
            is_state_end=True,
        )
        self.state.chat_history.append(turn)

        step_log = AgentActLog(chat_turns=[turn], info=None)

        # Keep counters in sync with IPDAgent's bookkeeping
        self.state.chat_counter = len(self.state.chat_history)
        self.state.round_nb = observation.round_nb

        return move, step_log
+
+
@dataclass
class AlwaysDefectIPDAgent(IPDAgent):
    async def act(self, observation) -> Tuple[Any, AgentActLog]:
        """
        Unconditionally defect, regardless of the observation.

        Emits the configured defect_string so the simulation parses the move
        as "D", records one assistant chat turn for log parity with the
        LLM-backed agents, and advances the state counters.
        """
        move = self.defect_string

        # One structured assistant turn so trajectories stay uniform
        turn = ChatTurn(
            agent_id=self.agent_id,
            role="assistant",
            content=f"Playing defect: {move}",
            is_state_end=True,
        )
        self.state.chat_history.append(turn)

        step_log = AgentActLog(chat_turns=[turn], info=None)

        # Keep counters in sync with IPDAgent's bookkeeping
        self.state.chat_counter = len(self.state.chat_history)
        self.state.round_nb = observation.round_nb

        return move, step_log
+
diff --git a/src_code_for_reproducibility/markov_games/ipd/__init__.py b/src_code_for_reproducibility/markov_games/ipd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f2388f6380fd3a54a2c80d1f1f77ae1d1fd4c8
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/ipd/__init__.py
@@ -0,0 +1,7 @@
+from .Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
+
+__all__ = [
+ "AlwaysCooperateIPDAgent",
+ "AlwaysDefectIPDAgent",
+]
+
diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..354fbab18d57425054c5e83b8a8c467d967ddccc
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ad972161150c1335d840f5d86f3a662a30f5826
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1154ff84f7e1316b3bb49a684c6e5c789276cb1
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c559efc356f340201bb70a3558c4b9abe6a5c7df
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0bf55aa8dc1407f9fa71222cef9dd1ef44f6be6d
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py b/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6f64b542dcc9ee7e114e617bde9cc1181ea301
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py
@@ -0,0 +1,115 @@
+import copy
+import json
+import random
+import re
+from collections.abc import Callable
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from mllm.markov_games.agent import Agent
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+
+
@dataclass
class IPDAgentState:
    """
    Mutable per-episode state of an IPDAgent.
    """

    nb_retries: int  # failed-attempt counter (presumably parse retries; not read in this file — TODO confirm)
    round_nb: int  # index of the last round this agent acted in
    chat_counter: int  # len(chat_history) at the end of the previous act(); used to slice the new turns
    chat_history: List[ChatTurn]  # full conversation with the policy so far
+
+
@dataclass
class IPDAgent(Agent):
    """
    LLM-backed agent for the Iterated Prisoner's Dilemma.

    Maintains a chat transcript with its policy. Each act() call appends any
    new user-side context (intro on round 0, the co-agent's previous move on
    a new round), queries the policy for a move constrained to the
    cooperate/defect strings, and logs the chat turns added during the call.
    """

    seed: int
    agent_id: str
    agent_name: str
    policy: Callable[[List[Dict]], str]
    intro_prompt: str  # Introduction prompt explaining the game rules
    goal_prompt: str  # Prompt explaining the agent's goal
    strategy_prompt: str  # Prompt suggesting a strategy to the agent
    max_errors: int  # Maximum number of errors allowed before default action
    allow_reasoning: bool  # Whether to allow reasoning in the response
    max_reasoning_chars: int  # Maximum number of characters for reasoning
    cooperate_string: str  # string parsed as playing cooperate by simulation
    defect_string: str  # string parsed as playing defect by simulation

    def __post_init__(self):
        # Fresh per-episode state; reset() restores the same initial values.
        self.state = IPDAgentState(
            nb_retries=0, round_nb=0, chat_counter=0, chat_history=[]
        )

    async def act(self, observation) -> Tuple[Any, AgentActLog]:
        """
        Produce the agent's move for the current round.

        Args:
            observation: IPD observation exposing `round_nb` and
                `last_coagent_move`.

        Returns:
            Tuple of (action string, AgentActLog covering the chat turns
            added during this call).
        """
        round_nb = observation.round_nb

        # If it's the first round, we need to send the intro prompt
        if round_nb == 0 and self.state.chat_counter == 0:
            self.state.chat_history.append(
                ChatTurn(
                    agent_id=self.agent_id,
                    role="user",
                    content=self.intro_prompt,
                    is_state_end=True,
                )
            )

        # If new round, report the co-agent's previous move
        if round_nb > self.state.round_nb:
            coagent_action = observation.last_coagent_move
            user_message = f"Last round, the other agent played {coagent_action}."
            self.state.chat_history.append(
                ChatTurn(
                    agent_id=self.agent_id,
                    role="user",
                    content=user_message,
                    is_state_end=True,
                )
            )

        # Query the policy; the regex constrains the reply to a valid move.
        output_chat_turn: ChatTurn = await self.policy(
            state=self.state.chat_history,
            agent_id=self.agent_id,
            regex=f"({self.cooperate_string}|{self.defect_string})",
        )
        self.state.chat_history.append(output_chat_turn)
        action = output_chat_turn.content

        # Log only the turns added since the previous act() call.
        agent_step_log = AgentActLog(
            chat_turns=self.state.chat_history[self.state.chat_counter :], info=None
        )
        self.state.chat_counter = len(self.state.chat_history)
        self.state.round_nb = round_nb

        return action, agent_step_log

    def get_safe_copy(self):
        """
        Return a safe copy of the agent: the agent is shallow-copied and its
        state deep-copied, so the copy's state can be mutated independently.
        """
        agent_copy = copy.copy(self)
        agent_copy.state = copy.deepcopy(self.state)
        return agent_copy

    def reset(self):
        """
        Reset the per-episode state to its initial values.

        BUG FIX: the original called IPDAgentState() without its required
        fields (a TypeError) and then had an unreachable
        `raise NotImplementedError`.
        """
        self.state = IPDAgentState(
            nb_retries=0, round_nb=0, chat_counter=0, chat_history=[]
        )

    def render(self):
        pass

    def close(self):
        pass

    def get_agent_info(self):
        pass
diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_simulation.py b/src_code_for_reproducibility/markov_games/ipd/ipd_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..238c7d319bd6e679284d2636aeffa194662a664b
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/ipd/ipd_simulation.py
@@ -0,0 +1,162 @@
+import copy
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from mllm.markov_games.markov_game import Simulation
+from mllm.markov_games.rollout_tree import SimulationStepLog
+from mllm.utils.get_coagent_id import get_coagent_id
+
+
+@dataclass
+class IPDState:
+    """
+    State of the Iterated Prisoner's Dilemma game.
+    """
+
+    # Number of completed rounds.
+    round_nb: int = 0
+    # Whether the game has ended.
+    done: bool = False
+    # Raw actions of the last round keyed by agent id (None before round 1).
+    last_moves: Dict[str, str] | None = None
+
+
+@dataclass
+class IPDObs:
+    """
+    Observation in Iterated Prisoner's Dilemma game.
+    """
+
+    # Current round index.
+    round_nb: int
+    # Raw action the co-agent played last round (None in the first round).
+    last_coagent_move: str | None
+
+
+class IPD(Simulation):
+ """
+ Iterated Prisoner's Dilemma simulation following the standard.
+
+ In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
+ The payoffs are as follows:
+ - If both cooperate: Both receive the "reward" (usually 3 points)
+ - If both defect: Both receive the "punishment" (usually 1 point)
+ - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
+ and the cooperator receives the "sucker" payoff (usually 0 points)
+
+ The game is played for a specified number of rounds.
+ """
+
+ def __init__(
+ self,
+ agent_ids: List[str],
+ agent_names: List[str],
+ seed: int,
+ rounds_per_game: int,
+ reward: float, # Both cooperate
+ punishment: float, # Both defect
+ temptation: float, # Defector's reward when other cooperates
+ sucker: float, # Cooperator's reward when other defects
+ cooperate_actions: List[str],
+ defect_actions: List[str],
+ ):
+ self.agent_ids = agent_ids
+ self.agent_names = agent_names
+ self.seed = seed
+ self.rounds_per_game = rounds_per_game
+ self.reward = reward
+ self.punishment = punishment
+ self.temptation = temptation
+ self.sucker = sucker
+ self.cooperate_actions = cooperate_actions
+ self.defect_actions = defect_actions
+ self.state = IPDState()
+
+ def step(self, actions: Dict[str, str]) -> Tuple[bool, SimulationStepLog]:
+ """
+ Take a step in the environment using the provided actions.
+ Here, the observations are just the states of the game.
+
+ Args:
+ actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
+
+ Returns:
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
+ done (bool): Whether the episode has ended.
+ info (dict): Additional information about the environment.
+ """
+
+ # Calculate rewards using payoff matrix
+ agent0_action = actions[self.agent_ids[0]]
+ agent1_action = actions[self.agent_ids[1]]
+
+ # Normalize actions to standard cooperate/defect/gibberish format
+ def normalize_action(action):
+ if action in self.cooperate_actions:
+ return "C"
+ elif action in self.defect_actions:
+ return "D"
+ else:
+ return "D"
+
+ norm_action0 = normalize_action(agent0_action)
+ norm_action1 = normalize_action(agent1_action)
+
+ payoffs = {
+ ("C", "C"): [self.reward, self.reward],
+ ("C", "D"): [self.sucker, self.temptation],
+ ("D", "C"): [self.temptation, self.sucker],
+ ("D", "D"): [self.punishment, self.punishment],
+ }
+
+ round_rewards = {
+ self.agent_ids[0]: payoffs[(norm_action0, norm_action1)][0],
+ self.agent_ids[1]: payoffs[(norm_action0, norm_action1)][1],
+ }
+
+ # Update game state
+ self.state.round_nb += 1
+ self.state.last_moves = copy.deepcopy(actions)
+ done = self.state.round_nb >= self.rounds_per_game
+ step_log = SimulationStepLog(
+ rewards=round_rewards,
+ info={
+ "actions": {
+ self.agent_ids[0]: norm_action0,
+ self.agent_ids[1]: norm_action1,
+ }
+ },
+ )
+
+ return done, step_log
+
+ def get_obs(self):
+ """Returns all agent observations in dict
+ Returns:
+ observations
+ """
+ observations = {}
+ for agent_id in self.agent_ids:
+ observations[agent_id] = self.get_obs_agent(agent_id)
+ return observations
+
+ def get_obs_agent(self, agent_id):
+ """Returns observation for agent_id"""
+ if self.state.last_moves != None:
+ other_id = get_coagent_id(self.agent_ids, agent_id)
+ last_coagent_move = self.state.last_moves[other_id]
+ else:
+ last_coagent_move = None
+ obs = IPDObs(round_nb=self.state.round_nb, last_coagent_move=last_coagent_move)
+ return obs
+
+ def reset(self):
+ """Returns initial observations and states"""
+ self.state = IPDState()
+ return self.get_obs()
+
+ def get_safe_copy(self):
+ """
+ Return a safe copy of the simulation.
+ """
+ simulation_copy = copy.copy(self)
+ simulation_copy.state = copy.deepcopy(self.state)
+ return simulation_copy
diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py b/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py
new file mode 100644
index 0000000000000000000000000000000000000000..8740fda6bc2550c92aef27ed9fbe7bc945be42ca
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from typing import Dict, Callable, List, Tuple
+
+from mllm.markov_games.rollout_tree import SimulationStepLog
+
+
+def avg_reward(sl: SimulationStepLog) -> List[Tuple[str, float]]:
+ for aid in sl.rewards.keys():
+ if "buffer" in str(aid) and "live" not in str(aid):
+ return None
+ # One value per agent at each step
+ rewards_dict = {f"reward-{aid}": float(v) for aid, v in (sl.rewards or {}).items()}
+ return [(key, value) for key, value in rewards_dict.items() if value is not None]
+
+# Registry of per-step statistic functions; each maps a SimulationStepLog to
+# a list of (metric_name, value) pairs, or None to skip the step.
+stat_functs: list[Callable[[SimulationStepLog], List[Tuple[str, float]]]] = [
+    avg_reward,
+]
\ No newline at end of file
diff --git a/src_code_for_reproducibility/markov_games/markov_game.py b/src_code_for_reproducibility/markov_games/markov_game.py
new file mode 100644
index 0000000000000000000000000000000000000000..73a48213bddcf0a59976fa0870eec19f59ae47d9
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/markov_game.py
@@ -0,0 +1,208 @@
+"""
+This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
+In a MarkovGame step,
+ 1) each agent takes an action,
+ 2) the state transitions with respect to these actions,
+ 3) all relevant data of the step is appended to the historical data list
+
+In order to perform 3), the agents and the simulation are expected, at each time step,
+to return a log of the state transition (from their perspective).
+For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
+A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
+The approach we use here centralizes the data gathering aspect,
+making it easy to create sub-trajectory descriptions (in the `runners` defined in `runners.py`) that
+only log information for step transitions occurring after the branching out.
+"""
+import asyncio
+import copy
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, List, Literal, Optional, Tuple
+
+from transformers.models.idefics2 import Idefics2Config
+
+from mllm.markov_games.agent import Agent
+from mllm.markov_games.rollout_tree import AgentActLog, StepLog
+from mllm.markov_games.simulation import Simulation
+
+AgentId = str
+
+
+@dataclass
+class AgentAndActionSafeCopy:
+    """Bundle of an agent's chosen action, its act log, and the post-action agent copy."""
+
+    # The action returned by Agent.act.
+    action: Any
+    # Log of the chat turns / info produced while acting.
+    action_info: AgentActLog
+    # Safe copy of the agent after it acted (an Agent instance, not a class).
+    agent_after_action: Agent
+
+
+class MarkovGame(object):
+    """
+    Couples a Simulation with the Agents acting in it and steps them jointly,
+    collecting a StepLog per transition (see the module docstring).
+    """
+
+    def __init__(
+        self,
+        id: int,
+        agents: dict[AgentId, Agent],
+        simulation: Simulation,
+        crn_id: int,
+    ):
+        """
+        Args:
+            id: Identifier of this game instance.
+            agents: Mapping from agent id to the Agent instance acting in the game.
+            simulation: Simulation object. Example: IPDSimulation
+            crn_id: Common-random-numbers id — presumably used to pair rollouts
+                sharing random seeds; confirm against the runners.
+        """
+        self.agents = agents
+        # NOTE: a live view on the dict's keys, not a snapshot.
+        self.agent_ids = self.agents.keys()
+        self.simulation = simulation
+        self.simulation_step_log = None
+        self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids}
+        self.actions = {}
+        self.id = id
+        self.crn_id = crn_id
+
+    def get_id(self) -> int:
+        """Return the game id."""
+        return self.id
+
+    def get_crn_id(self) -> int:
+        """Return the common-random-numbers id."""
+        return self.crn_id
+
+    def get_agent_ids(self) -> List[AgentId]:
+        """Return the agent ids as a list."""
+        return list(self.agent_ids)
+
+    async def get_action_of_agent_without_side_effects(
+        self, agent_id: AgentId
+    ) -> AgentAndActionSafeCopy:
+        """
+        Safe function to get an action of an agent without modifying the agent or the simulation.
+
+        The registered agent is rolled back to its pre-action copy; the
+        mutated (post-action) agent is returned separately so the caller can
+        decide whether to commit it (see set_action_and_agent_after_action_manually).
+        """
+        agent = self.agents[agent_id]
+        agent_before_action = agent.get_safe_copy()
+        obs = self.simulation.get_obs_agent(agent_id)
+        action, action_info = await agent.act(observation=obs)
+        # Restore the pre-action agent in the registry...
+        self.agents[agent_id] = agent_before_action
+        # ...and hand back a safe copy of the post-action agent.
+        agent_after_action = agent.get_safe_copy()
+        return AgentAndActionSafeCopy(action, action_info, agent_after_action)
+
+    async def get_actions_of_agents_without_side_effects(
+        self,
+    ) -> dict[AgentId, AgentAndActionSafeCopy]:
+        """
+        Safe function to get the actions of all agents, concurrently, without
+        modifying the agents or the simulation.
+        """
+        tasks = []
+        for agent_id in self.agent_ids:
+            task = asyncio.create_task(
+                self.get_action_of_agent_without_side_effects(agent_id)
+            )
+            tasks.append(task)
+        agent_and_action_safe_copies: list[
+            AgentAndActionSafeCopy
+        ] = await asyncio.gather(*tasks)
+        # gather preserves task order, so zipping with agent_ids is safe.
+        return {
+            agent_id: agent_and_action_safe_copy
+            for agent_id, agent_and_action_safe_copy in zip(
+                self.agent_ids, agent_and_action_safe_copies
+            )
+        }
+
+    def set_action_and_agent_after_action_manually(
+        self,
+        agent_id: AgentId,
+        agent_action_safe_copy: AgentAndActionSafeCopy,
+    ):
+        """
+        Commit a previously computed action: store the action and its log, and
+        replace the registered agent with its post-action copy.
+        """
+        self.actions[agent_id] = agent_action_safe_copy.action
+        self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info
+        self.agents[agent_id] = agent_action_safe_copy.agent_after_action
+
+    def set_actions_of_agents_manually(
+        self, actions: dict[AgentId, AgentAndActionSafeCopy]
+    ):
+        """
+        Commit precomputed actions for several agents at once.
+        """
+        for agent_id, agent_action_safe_copy in actions.items():
+            self.set_action_and_agent_after_action_manually(
+                agent_id, agent_action_safe_copy
+            )
+
+    async def set_action_of_agent(self, agent_id: AgentId):
+        """
+        Let one agent observe the simulation and act, storing its action and act log.
+        """
+        agent = self.agents[agent_id]
+        obs = self.simulation.get_obs_agent(agent_id)
+        action, action_info = await agent.act(observation=obs)
+        self.actions[agent_id] = action
+        self.agent_step_logs[agent_id] = action_info
+
+    async def set_actions(self):
+        """
+        Let all agents act concurrently (see set_action_of_agent).
+        """
+        tasks = []
+        for agent_id in self.agent_ids:
+            task = asyncio.create_task(self.set_action_of_agent(agent_id))
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+
+    def take_simulation_step(self):
+        """
+        Advance the simulation with the stored actions; return whether it terminated.
+        """
+        terminated, self.simulation_step_log = self.simulation.step(self.actions)
+        return terminated
+
+    def get_step_log(self) -> StepLog:
+        """
+        Bundle the latest simulation step log with the agents' act logs.
+        TODO: assert actions and simulation have taken step
+        """
+        step_log = StepLog(
+            simulation_step_log=self.simulation_step_log,
+            action_logs=self.agent_step_logs,
+        )
+        return step_log
+
+    async def step(self) -> Tuple[bool, StepLog]:
+        """
+        Run one full Markov-game step: all agents act, the simulation
+        transitions, and the combined step log is returned.
+        """
+        await self.set_actions()
+        terminated = self.take_simulation_step()
+        step_log = self.get_step_log()
+        return terminated, step_log
+
+    def get_safe_copy(self):
+        """
+        Return a copy of the game that can be stepped independently: fresh
+        simulation and agent copies, deep-copied actions and step logs.
+        """
+
+        new_markov_game = copy.copy(self)
+        new_simulation = self.simulation.get_safe_copy()
+        new_agents = {
+            agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items()
+        }
+
+        # Reassign copied components
+        new_markov_game.simulation = new_simulation
+        new_markov_game.agents = new_agents
+
+        # IMPORTANT: ensure agent_ids references the new agents dict, not the original
+        new_markov_game.agent_ids = new_markov_game.agents.keys()
+
+        # Deep-copy step data to avoid correlation
+        new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log)
+        new_markov_game.actions = copy.deepcopy(self.actions)
+        # Rebuild logs to align exactly with new agent ids
+        old_agent_step_logs = copy.deepcopy(self.agent_step_logs)
+        new_markov_game.agent_step_logs = {
+            agent_id: old_agent_step_logs.get(agent_id)
+            for agent_id in new_markov_game.agent_ids
+        }
+
+        return new_markov_game
diff --git a/src_code_for_reproducibility/markov_games/negotiation/README.md b/src_code_for_reproducibility/markov_games/negotiation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c8ebadee705971c5331924ed1b9d53c7e5f69770
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/README.md
@@ -0,0 +1,40 @@
+## Negotiation Games: core mechanics and variants
+
+This family of games features two agents who, in each round, may briefly communicate and then simultaneously propose how to split a fixed resource (most commonly 10 coins). Rewards are the amount kept multiplied by an agent’s per-unit value. The starting speaker alternates deterministically across rounds.
+
+Communication is optional and variant-dependent: some settings encourage rich messaging to share private information, while others remove messaging entirely to focus on allocation behavior.
+
+Proportional splitting is used when the two proposals exceed the available total: allocations are scaled proportionally rather than discarded. This preserves a useful learning signal even when agents over-claim.
+
+### Variants (in increasing difficulty)
+
+- No‑Press Split
+ - Single item type (coins)
+ - No communication; agents go straight to making split proposals, with the starting player alternating deterministically.
+ - Motivation: mirrors no‑communication setups (e.g., Advantage Alignment) while keeping the split decision nontrivial.
+ - Deterministic Mode: values are fixed and public: one agent values coins at 10, the other at 1 (alternates each round).
+ - Stochastic Mode: values are random and uncorrelated.
+
+- Trust-and-Split RPS (TAS-RPS)
+ - Single item type (coins)
+ - Each round, a rock–paper–scissors hand draw creates a strong asymmetry: the winner’s per-coin value is 10, the loser’s is 1.
+ - Each agent initially sees only their own hand and must communicate to coordinate an optimal split.
+ - Motivation: enforce large value disparity so one’s own value reveals little about the other’s (avoiding ceiling effects) and incentivize meaningful communication.
+
+- Trust-and-Split (TAS)
+ - Single item type (coins); each round, each agent’s per-coin value is independently sampled in a broad range (e.g., 1–20).
+ - Each agent observes only their own value; they may use short messages to share and negotiate.
+ - Motivation: a simple blend that tests whether agents learn to exchange private information and coordinate proportional, value-aware splits.
+
+- Deal-or-No-Deal (DOND)
+ - Introduced in [Deal or No Deal? End-to-End Learning for Negotiation Dialogues](https://arxiv.org/pdf/1706.05125)
+ - Multiple item types (typically "books", "hats" and "balls") with limited stocks; each agent has its own per-type values.
+ - A deal pays out only if both proposals exactly agree and respect the stock; otherwise no deal (zero reward) that round.
+ - Motivation: a known benchmark closer to real-world bargaining, where both parties must explicitly agree.
+
+
+
+
+
+
+
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c3cdd472992d134e05776b55eeeaea51ef72feb
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f30a76c150da8a74c1987b2d712c714b6b3bea49
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5e0ddf5cac0cf4f5b9909e743fbaee6940fc4ab
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d0e8f66077980f9dca7955a12d324c195ff89ce
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cea4e60166d85502b70f5d5ab56eaa4707eb7e9
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebb546838e912f12d62abfe2b22a71ce2adbd98e
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad86d92929bd6c38a1aeba2fa3d8ad619cab45df
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c9443951c86d71f7d7150be5007aee9061220f2
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c381a73774bad0361fbe12560a32c645ca13dc
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9228dcd110476851547e76bac7ba6f795c991a0a
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3e363d683aca4bb5e1f4ed73801043a87b64d50
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..708db5895faf46b2e0eba53b1fe8567e7b17cde1
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67cba7cc63da526ed267724c423a113da94355ea
Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/markov_games/negotiation/dond_agent.py b/src_code_for_reproducibility/markov_games/negotiation/dond_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..958756f16e1ab0e5b348d2ae7a37777be2a0ad21
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/dond_agent.py
@@ -0,0 +1,61 @@
+import copy
+import re
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+from mllm.markov_games.agent import Agent
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+from mllm.markov_games.negotiation.dond_simulation import (
+ DealNoDealObs,
+)
+from mllm.markov_games.negotiation.nego_simulation import Split
+from mllm.markov_games.negotiation.nego_agent import NegotiationAgent, NegotiationAgentState
+
+class DealNoDealAgent(NegotiationAgent):
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.intro_prompt = (
+ "You are {agent_id}. You are playing an iterated game. "
+ "At each round, you and other agent will try to distribute among yourselves items of types {item_types}. "
+ "You only know how much you value each item type, but not the other agent's values. "
+ "You can communicate with the other agent by sending up to {quota_messages_per_agent_per_round} short messages per round. "
+ "Each round, after exchanging messages, you and the other agent will submit a private proposal. "
+ "A deal is accepted only if both proposals match exactly and are within stock; otherwise no deal (0 points for both at that round). "
+ "The values of the items of the other agent at the previous round are revealed to you after each round. "
+ "Your goal is: {goal}."
+ )
+ self.new_round_prompt = ("New round {round_nb}. Items: {stock}. Your values: {values}. ")
+ self.last_round_prompt = ("Last round, other agent's values: {previous_values_coagent}. ")
+ self.send_split_prompt = ("Respond with ... where you propose how many items of each type you want to keep.")
+
+ def get_message_regex(self, observation: DealNoDealObs) -> str:
+ return r"[\s\S]{0,400}"
+
+ def get_split_regex(self, observation: DealNoDealObs) -> str:
+ parts = []
+ for t in observation.item_types:
+ s = int(observation.quantities.get(t, 0))
+ allowed = "|".join(str(k) for k in range(0, s + 1))
+ rng = f"({allowed})"
+ parts.append(fr"<{t}>{rng}{t}>")
+ items_block = "".join(parts)
+ return fr"({items_block})"
+
+ def get_split_action(self, policy_output: str, observation: DealNoDealObs) -> Split:
+ import re as _re
+ allocations: Dict[str, int] = {}
+ for t in observation.item_types:
+ m = _re.search(fr"<{t}>([0-9]+){t}>", policy_output)
+ if m:
+ allocations[t] = int(m.group(1))
+ else:
+ allocations[t] = 0
+ return Split(items_given_to_self=allocations)
+
+
+
diff --git a/src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27d6ce2cf7e31a0cddd341db39ae7898b086115
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py
@@ -0,0 +1,153 @@
+import copy
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+from numpy.random import default_rng
+
+from mllm.markov_games.rollout_tree import SimulationStepLog
+from mllm.markov_games.negotiation.nego_simulation import Split, NegotiationState, NegotiationObs, NegotiationSimulation
+from mllm.utils.get_coagent_id import get_coagent_id
+
+
+AgentId = str
+
+
+@dataclass
+class DealNoDealState(NegotiationState):
+    """Negotiation state extended with DOND item types and per-agent item values."""
+
+    # Item types in play (e.g. ["books", "hats", "balls"]).
+    item_types: List[str]
+    # Per-agent mapping of item type -> integer value.
+    values: Dict[AgentId, Dict[str, int]]
+
+@dataclass
+class DealNoDealObs(NegotiationObs):
+    """Negotiation observation extended with the observer's private item values."""
+
+    # This agent's own item values (type -> value).
+    my_values: Dict[str, int]
+    # Item types in play.
+    item_types: List[str]
+    # Co-agent's item values as exposed to this agent; see the simulation's
+    # get_obs_agent for exactly when these are revealed.
+    previous_values_coagent: Dict[str, int] | None
+
+
+def random_partition_integer(rng, total: int, parts: int) -> List[int]:
+ if parts <= 0:
+ return []
+ if total <= 0:
+ return [0 for _ in range(parts)]
+ cuts = sorted(rng.integers(0, total + 1, size=parts - 1).tolist())
+ vals = []
+ prev = 0
+ for c in cuts + [total]:
+ vals.append(c - prev)
+ prev = c
+ return vals
+
+class DealNoDealSimulation(NegotiationSimulation):
+
+ def __init__(
+ self,
+ item_types: List[str] = ["books", "hats", "balls"],
+ *args,
+ **kwargs,
+ ):
+ super().__init__(item_types=item_types, *args, **kwargs)
+ self.reset()
+
+ def _other(self, agent_id: AgentId) -> AgentId:
+ return get_coagent_id(self.agent_ids, agent_id)
+
+ def _sample_stock(self) -> Dict[str, int]:
+ # total items between 5 and 7
+ total_items = int(self.rng.integers(5, 8))
+ # nonnegative per-type counts summing to total_items
+ parts = random_partition_integer(self.rng, total_items, len(self.item_types))
+ # allow zeros per type
+ return {t: int(c) for t, c in zip(self.item_types, parts)}
+
+ def _sample_values_pair(self) -> Dict[AgentId, Dict[str, int]]:
+ # Each agent has integer non-negative values that sum to 10
+ # Each item type valued by at least one agent
+ # Some item type valued by both agents
+ while True:
+ vals_a = random_partition_integer(self.rng, 10, len(self.item_types))
+ vals_b = random_partition_integer(self.rng, 10, len(self.item_types))
+ a = {t: int(v) for t, v in zip(self.item_types, vals_a)}
+ b = {t: int(v) for t, v in zip(self.item_types, vals_b)}
+ # each item valued by at least one
+ ok1 = all((a[t] > 0) or (b[t] > 0) for t in self.item_types)
+ # some item valued by both
+ ok2 = any((a[t] > 0) and (b[t] > 0) for t in self.item_types)
+ if ok1 and ok2:
+ return {self.agent_ids[0]: a, self.agent_ids[1]: b}
+
+ def _is_valid_allocation(self, allocation: Dict[str, int], stock: Dict[str, int]) -> bool:
+ for t in self.item_types:
+ v = allocation.get(t)
+ if v is None:
+ return False
+ if not isinstance(v, int):
+ return False
+ if v < 0 or v > int(stock.get(t, 0)):
+ return False
+ return True
+
+ def set_new_round_of_variant(self):
+ # Keep same values, resample stock
+ self.state.quantities = self._sample_stock()
+
+ def get_info_of_variant(self, state: NegotiationState, actions: Dict[AgentId, Any]) -> Dict[str, Any]:
+ return {
+ "quantities": copy.deepcopy(state.quantities),
+ "values": copy.deepcopy(state.values),
+ 'splits': copy.deepcopy(state.splits),
+ }
+
+ def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
+ """
+ Returns the rewards for each agent.
+ """
+ split_a = splits[self.agent_ids[0]].items_given_to_self
+ split_b = splits[self.agent_ids[1]].items_given_to_self
+ rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
+ for t in self.item_types:
+ # If not complementary, return 0!
+ if not split_a[t] + split_b[t] == self.state.quantities[t]:
+ return {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
+ rewards[self.agent_ids[0]] += split_a[t] * self.state.values[self.agent_ids[0]][t]
+ rewards[self.agent_ids[1]] += split_b[t] * self.state.values[self.agent_ids[1]][t]
+ return rewards
+
+ def get_obs(self):
+ return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}
+
+ def get_obs_agent(self, agent_id):
+ other_id = self._other(agent_id)
+ obs = DealNoDealObs(
+ round_nb=self.state.round_nb,
+ last_message=self.state.last_message,
+ current_agent=self.state.current_agent,
+ quantities=copy.deepcopy(self.state.quantities),
+ value=0.0, # unused in DOND
+ other_agent_split=None, # not meaningful until split
+ split_phase=self.state.split_phase,
+ quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
+ my_values=copy.deepcopy(self.state.values[agent_id]),
+ item_types=list(self.item_types),
+ previous_values_coagent=copy.deepcopy(self.state.values.get(other_id, {})),
+ )
+ return obs
+
+ def reset(self):
+ start_agent = self.agent_ids[self._starting_agent_index]
+ stock = self._sample_stock()
+ values = self._sample_values_pair()
+ self.state = DealNoDealState(
+ round_nb=0,
+ last_message="",
+ current_agent=start_agent,
+ quantities=stock,
+ values=values,
+ previous_values=None,
+ splits={aid: None for aid in self.agent_ids},
+ nb_messages_sent={aid: 0 for aid in self.agent_ids},
+ split_phase=False,
+ item_types=list(self.item_types),
+ )
+ return self.get_obs()
+
+
diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py b/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5bf4e3ca4ee7faa982360674e19d9eff6980dc
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py
@@ -0,0 +1,242 @@
+import copy
+from abc import abstractmethod
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from mllm.markov_games.agent import Agent
+from mllm.markov_games.negotiation.nego_simulation import Message, NegotiationObs, Split
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+
+
+@dataclass
+class NegotiationAgentState:
+    """Mutable per-episode bookkeeping of a negotiation agent."""
+
+    round_nb: int  # last round number the agent has observed
+    nb_messages_sent_this_round: int  # messages sent in the current round
+    chat_counter: int  # length of chat_history already flushed to an AgentActLog
+    chat_history: List[ChatTurn]  # full transcript accumulated so far
+
+
+class NegotiationAgent(Agent):
+    """Base class for LLM-driven negotiation agents.
+
+    Builds a running chat transcript out of prompt templates (filled in by
+    concrete variants), queries ``policy`` for messages/splits constrained by
+    a regex, and logs every chat turn so rollouts can be reconstructed.
+    """
+
+    def __init__(
+        self,
+        seed: int,
+        agent_id: str,
+        agent_name: str,
+        policy: Callable[[List[Dict]], str],
+        goal: str,
+        exploration_prompts: List[str] | None = None,
+        exploration_prompt_probs: List[float] | None = None,
+    ):
+        """Initialize the agent.
+
+        Args:
+            seed: RNG seed for exploration-prompt sampling.
+            agent_id: Unique id of this agent in the game.
+            agent_name: Human-readable name injected into prompts.
+            policy: Async callable producing model output for a chat state.
+                NOTE(review): it is awaited with keyword arguments
+                ``state``/``agent_id``/``regex``; the ``Callable`` hint is
+                looser than the real contract — confirm against callers.
+            goal: Objective text injected into the intro prompt.
+            exploration_prompts: Optional prompts occasionally appended at
+                round starts to encourage exploration.
+            exploration_prompt_probs: Non-negative probabilities for the
+                prompts, summing to at most 1; the remaining mass is the
+                chance of appending no exploration prompt.
+        """
+        self.seed = seed
+        self.agent_id = agent_id
+        self.agent_name = agent_name
+        self.policy = policy
+        self.goal = goal
+        # Own seeded generator: previously the global np.random state was used
+        # for prompt sampling, which silently ignored ``seed``.
+        self._rng = np.random.default_rng(seed)
+        # Normalize to fresh lists (avoids the mutable-default-argument trap
+        # the original ``=[]`` defaults exposed).
+        exploration_prompts = list(exploration_prompts) if exploration_prompts else []
+        exploration_prompt_probs = (
+            list(exploration_prompt_probs) if exploration_prompt_probs else []
+        )
+        self.exploration_prompts_toggled = len(exploration_prompts) > 0
+        if self.exploration_prompts_toggled:
+            # ``None`` acts as the "no exploration prompt" arm; its probability
+            # is the leftover mass 1 - sum(probs).
+            exploration_prompts.append(None)
+            self.exploration_prompts = exploration_prompts
+            self.exploration_prompt_probs = np.array(exploration_prompt_probs)
+            assert self.exploration_prompt_probs.sum() <= 1
+            assert np.all(self.exploration_prompt_probs >= 0)
+            self.exploration_prompt_probs = np.append(
+                self.exploration_prompt_probs, 1 - self.exploration_prompt_probs.sum()
+            )
+        self.state = NegotiationAgentState(
+            round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[]
+        )
+
+        # Prompt templates — implemented in variants
+        self.intro_prompt = ""
+        self.new_round_prompt = ""
+        self.last_round_prompt = ""
+        self.send_split_prompt = ""
+        self.wait_for_message_prompt = ""
+        self.last_message_prompt = ""
+        self.send_message_prompt = ""
+
+    @abstractmethod
+    def get_message_regex(self, observation: NegotiationObs) -> str:
+        """Regex constraining the policy output when sending a message."""
+
+    @abstractmethod
+    def get_split_regex(self, observation: NegotiationObs) -> str:
+        """Regex constraining the policy output when proposing a split."""
+
+    @abstractmethod
+    def get_split_action(
+        self, policy_output: str, observation: NegotiationObs
+    ) -> Split:
+        """Parse the raw policy output into a Split action."""
+
+    async def act(self, observation: NegotiationObs) -> Tuple[Any, AgentActLog]:
+        """Assemble this step's user prompt, query the policy if it is our
+        turn to speak or split, and return (action, log of new chat turns).
+        """
+        def dict_to_str(d: dict) -> str:
+            return ", ".join(f"{v} {k}" for k, v in d.items())
+
+        def dict_to_eq_str(d: dict) -> str:
+            return ", ".join(f"{k}={v}" for k, v in d.items())
+
+        is_our_turn = observation.current_agent == self.agent_id
+        action: Any = None
+        round_nb = observation.round_nb
+
+        prompt_parts: List[str] = []
+        obs_ctx = vars(observation)
+        obs_ctx_formatted = obs_ctx.copy()
+        # Dict-valued fields whose key mentions "value" render as "k=v";
+        # all other dicts render as "v k" (e.g. "3 books").
+        for key in obs_ctx_formatted:
+            if isinstance(obs_ctx_formatted[key], dict) and "value" not in key:
+                obs_ctx_formatted[key] = dict_to_str(obs_ctx_formatted[key])
+            elif isinstance(obs_ctx_formatted[key], dict) and "value" in key:
+                obs_ctx_formatted[key] = dict_to_eq_str(obs_ctx_formatted[key])
+
+        #######################################
+        # build user prompt
+        #######################################
+
+        # First-ever call
+        is_intro = round_nb == 0 and self.state.chat_counter == 0
+        if is_intro:
+            prompt_parts.append(
+                self.intro_prompt.format(
+                    goal=self.goal, agent=self.agent_name, **obs_ctx_formatted
+                )
+            )
+
+        # New round
+        is_new_round = round_nb > self.state.round_nb
+        if is_new_round or is_intro:
+            self.state.nb_messages_sent_this_round = 0
+            if not is_intro:
+                prompt_parts.append(self.last_round_prompt.format(**obs_ctx_formatted))
+            prompt_parts.append(self.new_round_prompt.format(**obs_ctx_formatted))
+            if self.exploration_prompts_toggled:
+                exploration_prompt = self.exploration_prompts[
+                    self._rng.choice(
+                        len(self.exploration_prompts), p=self.exploration_prompt_probs
+                    )
+                ]
+                if exploration_prompt is not None:
+                    prompt_parts.append(exploration_prompt)
+            self.state.round_nb = round_nb
+
+        # Wait for message
+        if not is_our_turn and not observation.split_phase:
+            prompt_parts.append(
+                self.wait_for_message_prompt.format(**obs_ctx_formatted)
+            )
+
+        # Get last message
+        if is_our_turn and not is_new_round and not is_intro:
+            prompt_parts.append(self.last_message_prompt.format(**obs_ctx_formatted))
+
+        # Prompt to send message
+        must_send_message = not observation.split_phase and is_our_turn
+        if must_send_message:
+            prompt_parts.append(self.send_message_prompt.format(**obs_ctx_formatted))
+
+        # Prompt to give split
+        must_send_split = not must_send_message and observation.split_phase
+        if must_send_split:
+            var_names = ["x", "y", "z", "w"]  # Extend as needed
+            items_str = ", ".join(
+                [
+                    f"{var_names[i]} {item}"
+                    for i, item in enumerate(obs_ctx["quantities"].keys())
+                ]
+            )
+            ranges_str = ", ".join(
+                [
+                    f"{var_names[i]}: 0-{obs_ctx['quantities'][item]} (integer)"
+                    for i, item in enumerate(obs_ctx["quantities"].keys())
+                ]
+            )
+            proposal_style = f"Proposal: {items_str} where {ranges_str}."
+            proposal_style2 = (
+                f" {items_str} where {ranges_str}."
+            )
+            prompt_parts.append(
+                self.send_split_prompt.format(
+                    proposal_style=proposal_style,
+                    proposal_style2=proposal_style2,
+                    **obs_ctx_formatted,
+                )
+            )
+
+        # Append one ChatTurn with is_state_end=True
+        user_prompt = "\n".join(prompt_parts)
+        self.state.chat_history.append(
+            ChatTurn(
+                agent_id=self.agent_id,
+                role="user",
+                content=user_prompt,
+                is_state_end=True,
+            )
+        )
+
+        #######################################
+        # Get policy action
+        #######################################
+
+        # Query policy for the appropriate format
+        if must_send_message:
+            return_regex = self.get_message_regex(observation)
+            policy_output = await self.policy(
+                state=self.state.chat_history,
+                agent_id=self.agent_id,
+                regex=return_regex,
+            )
+            self.state.chat_history.append(
+                ChatTurn(
+                    agent_id=self.agent_id,
+                    role="assistant",
+                    content=policy_output.content,
+                    reasoning_content=policy_output.reasoning_content,
+                    log_probs=policy_output.log_probs,
+                    out_token_ids=policy_output.out_token_ids,
+                    is_state_end=False,
+                )
+            )
+            action = Message(message=policy_output.content)
+            self.state.nb_messages_sent_this_round += 1
+
+        elif must_send_split:
+            return_regex = self.get_split_regex(observation)
+            policy_output = await self.policy(
+                state=self.state.chat_history,
+                agent_id=self.agent_id,
+                regex=return_regex,
+            )
+            self.state.chat_history.append(
+                ChatTurn(
+                    agent_id=self.agent_id,
+                    role="assistant",
+                    content=policy_output.content,
+                    reasoning_content=policy_output.reasoning_content,
+                    log_probs=policy_output.log_probs,
+                    out_token_ids=policy_output.out_token_ids,
+                    is_state_end=False,
+                )
+            )
+            action = self.get_split_action(policy_output.content, observation)
+        else:
+            # Not our turn to speak and not the split phase: no action.
+            action = None
+
+        # Log only the turns added since the previous act() call.
+        agent_step_log = AgentActLog(
+            chat_turns=self.state.chat_history[self.state.chat_counter :], info=None
+        )
+        self.state.chat_counter = len(self.state.chat_history)
+        return action, agent_step_log
+
+    def get_safe_copy(self):
+        """Shallow-copy the agent with a deep-copied state, so the copy can be
+        stepped without mutating this agent's transcript."""
+        agent_copy = copy.copy(self)
+        agent_copy.state = copy.deepcopy(self.state)
+        return agent_copy
+
+    def reset(self):
+        """Discard all per-episode state (round counter, transcript)."""
+        self.state = NegotiationAgentState(
+            round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[]
+        )
diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py b/src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b5c191e15ef6b0abada72b1b6ba3a4c59421fdf
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py
@@ -0,0 +1,64 @@
+import asyncio
+from typing import Optional
+from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
+from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
+from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressObs
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+from mllm.markov_games.negotiation.nego_simulation import Split
+from typing import Any, Tuple
+
+class HardCodedNegoWelfareMaximizingPolicy(NoPressAgent):
+    async def act(self, observation: NoPressObs) -> Tuple[Any, AgentActLog]:
+        """
+        Policy that gives all of the items to the agent who values them more.
+        If the items are equally valued, the quantity is split evenly between
+        the two agents.
+        """
+        quantities = observation.quantities
+        my_values = observation.value
+        other_values = observation.other_value
+
+        items_given_to_self = {}
+        for item, qty in quantities.items():
+            my_v = float(my_values.get(item, 0))
+            other_v = float(other_values.get(item, 0))
+            if my_v == other_v:
+                # NOTE(review): true division — an odd quantity yields a
+                # fractional share although Split declares Dict[str, int];
+                # confirm downstream handling.
+                items_given_to_self[item] = int(qty) / 2
+            else:
+                # Claim everything we value strictly more; concede otherwise.
+                items_given_to_self[item] = int(qty if my_v > other_v else 0)
+
+        action = Split(items_given_to_self=items_given_to_self)
+        act_log = AgentActLog(
+            chat_turns=[
+                ChatTurn(
+                    agent_id=self.agent_id,
+                    role="assistant",
+                    content="Using welfare-maximizing split (all to higher-value agent).",
+                    is_state_end=True,
+                )
+            ],
+            info=None,
+        )
+        return action, act_log
+
+class HardCodedNegoGreedyPolicy(NoPressAgent):
+    async def act(self, observation: NoPressObs) -> Tuple[Any, AgentActLog]:
+        """Claim every available unit of every item for itself."""
+        claim = {}
+        for item_name, available in observation.quantities.items():
+            claim[item_name] = int(available)
+
+        split_action = Split(items_given_to_self=claim)
+        turn = ChatTurn(
+            agent_id=self.agent_id,
+            role="assistant",
+            content="Using greedy split (keep all items).",
+            is_state_end=True,
+        )
+        log = AgentActLog(chat_turns=[turn], info=None)
+        return split_action, log
+
diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a4d18532ab0472cd8f83414d52cf6df589fe126
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py
@@ -0,0 +1,241 @@
+"""
+Negotiation simulation environment
+other agent is set at the start of every round. Even though current agent changes over message turns in a round.
+"""
+import copy
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+from numpy.random import default_rng
+
+from mllm.markov_games.rollout_tree import SimulationStepLog
+from mllm.markov_games.simulation import Simulation
+from mllm.utils.get_coagent_id import get_coagent_id
+
+AgentId = str
+
+
+@dataclass
+class Split:
+    """Split-phase action: units of each item the acting agent keeps for itself."""
+
+    items_given_to_self: Dict[str, int]  # item name -> quantity claimed
+
+
+@dataclass
+class Message:
+    """Message-phase action carrying free-form text to the other agent."""
+
+    message: str
+
+
+@dataclass  # gets extended by variants
+class NegotiationState:
+    """Full (hidden) state of a negotiation game."""
+
+    round_nb: int  # current round index (0-based)
+    last_message: str  # most recent message text, "" at round start
+    current_agent: AgentId  # whose turn it is in the message phase
+    quantities: Dict[str, int]  # item name -> available units this round
+    values: Dict[AgentId, Dict[str, float]]  # per-agent per-item valuations
+    splits: Dict[AgentId, Split | None]  # recorded split proposals this round
+    nb_messages_sent: Dict[AgentId, int]  # messages sent per agent this round
+    # Snapshots of the previous round (None before any round has completed):
+    previous_values: Dict[AgentId, Dict[str, float]] | None
+    previous_splits: Dict[AgentId, Dict[str, int] | None] | None
+    previous_points: Dict[AgentId, float] | None
+    previous_quantities: Dict[str, int] | None
+    split_phase: bool  # True once both agents exhausted their message quota
+
+
+@dataclass  # gets extended by variants
+class NegotiationObs:
+    """Per-agent view of the game; ``last_*`` fields describe the previous
+    round and are None before any round has completed."""
+
+    round_nb: int
+    last_message: str
+    quota_messages_per_agent_per_round: int
+    current_agent: AgentId  # id of the agent whose turn it is
+    other_agent: str
+    quantities: Dict[str, int]  # item name -> available units
+    item_types: List[str]
+    value: Dict[str, int]  # the observing agent's own per-item values
+    split_phase: bool
+    last_split_agent: Dict[str, int] | None
+    last_value_agent: Dict[str, int] | None
+    last_points_agent: float | None
+    last_split_coagent: Dict[str, int] | None
+    last_value_coagent: Dict[str, int] | None
+    last_points_coagent: float | None
+    last_quantities: Dict[str, int] | None
+
+
+def compute_tas_style_rewards(
+    agent_ids: List[AgentId],
+    values: Dict[AgentId, Any],
+    splits: Dict[AgentId, Split | None],
+    quantities: Dict[str, int],
+) -> Dict[AgentId, float]:
+    """
+    TAS-like reward computation: if sum of proposed coins exceeds max_coins,
+    allocate proportionally. Otherwise, use proposed amounts directly.
+    Rewards are quantity_kept * per-coin value for each agent.
+
+    Args:
+        agent_ids: Exactly two agent ids; fixes the pairing of the output.
+        values: Per-agent valuation — either one scalar per agent or a dict
+            mapping item name -> per-unit value.
+        splits: Per-agent proposals; ``None`` counts as claiming nothing.
+        quantities: Available units per item.
+    """
+    a0, a1 = agent_ids[0], agent_ids[1]
+    r0, r1 = 0.0, 0.0
+
+    for item in quantities:
+        max_item = quantities[item]
+        item_to_self_0 = int(
+            (splits[a0].items_given_to_self.get(item, 0))
+            if splits[a0] is not None
+            else 0
+        )
+        item_to_self_1 = int(
+            (splits[a1].items_given_to_self.get(item, 0))
+            if splits[a1] is not None
+            else 0
+        )
+        denom = max(int(max_item), item_to_self_0 + item_to_self_1)
+        if denom == 0:
+            # Zero stock and zero claims: nothing to allocate. Previously this
+            # fell through to a ZeroDivisionError.
+            continue
+        q0 = float(max_item) * float(item_to_self_0) / float(denom)
+        q1 = float(max_item) * float(item_to_self_1) / float(denom)
+        if isinstance(values[a0], dict):
+            r0 += q0 * float(values[a0][item])
+            r1 += q1 * float(values[a1][item])
+        else:
+            # Scalar valuation shared across all items (e.g. coin-only games).
+            r0 += q0 * float(values[a0])
+            r1 += q1 * float(values[a1])
+    return {a0: r0, a1: r1}
+
+
+class NegotiationSimulation(Simulation):
+    """Abstract two-agent, multi-round negotiation game.
+
+    Each round is an alternating message phase (bounded by a per-agent quota)
+    followed by a simultaneous split phase. Variants implement sampling,
+    observations and rewards via the abstract hooks below.
+    """
+
+    def __init__(
+        self,
+        agent_ids: List[AgentId],
+        agent_names: List[str],
+        seed: int,
+        nb_of_rounds: int,
+        quota_messages_per_agent_per_round: int,
+        item_types: List[str] | None = None,
+    ):
+        """Set up RNG and bookkeeping, then start the first episode via reset()."""
+        self.seed = seed
+        self.rng = default_rng(self.seed)
+        self.agent_ids = list(agent_ids)
+        self.agent_names = agent_names
+        self.agent_id_to_name = {
+            agent_id: agent_name for agent_id, agent_name in zip(agent_ids, agent_names)
+        }
+        self.nb_of_rounds = int(nb_of_rounds)
+        self.quota_messages_per_agent_per_round = int(
+            quota_messages_per_agent_per_round
+        )
+        # Item names are normalized to lowercase; default is a coin-only game.
+        if item_types is not None:
+            self.item_types = [item.lower() for item in item_types]
+        else:
+            self.item_types = ["coins"]
+        self.state: NegotiationState | None = None
+        self._starting_agent_index = self.rng.choice([0, 1])
+        self.reset()
+
+    def _other(self, agent_id: AgentId) -> AgentId:
+        """Return the id of the other agent."""
+        return get_coagent_id(self.agent_ids, agent_id)
+
+    @abstractmethod
+    def set_new_round_of_variant(self):
+        """Variant hook: resample round-specific state (values, stock, ...)."""
+
+    @abstractmethod
+    def get_info_of_variant(
+        self, state: NegotiationState, actions: Dict[AgentId, Any]
+    ) -> Dict[str, Any]:
+        """Variant hook: build the info dict logged at the end of a round."""
+
+    def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
+        """
+        Returns terminated, step_log
+        """
+        assert self.state is not None
+        current_agent = self.state.current_agent
+        a0, a1 = self.agent_ids[0], self.agent_ids[1]
+        action = actions.get(current_agent)
+
+        # Split phase: require both splits in the same timestep
+        if self.state.split_phase:
+            action_a0 = actions.get(a0)
+            action_a1 = actions.get(a1)
+            have_both_splits = isinstance(action_a0, Split) and isinstance(
+                action_a1, Split
+            )
+            if not have_both_splits:
+                rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
+                return False, SimulationStepLog(
+                    rewards=rewards, info={"type": "waiting_for_splits"}
+                )
+
+            # Record splits
+            self.state.splits[a0] = action_a0
+            self.state.splits[a1] = action_a1
+
+            # Compute rewards and end round
+            rewards = self.get_rewards(self.state.splits)
+
+            # Info
+            info = self.get_info_of_variant(self.state, actions)
+
+            # Prepare next round
+            # Alternate starting agent
+            self.state.round_nb += 1
+            self._starting_agent_index = 1 - self._starting_agent_index
+            self.state.current_agent = self.agent_ids[self._starting_agent_index]
+            self.state.previous_values = copy.deepcopy(self.state.values)
+            self.state.previous_splits = copy.deepcopy(self.state.splits)
+            self.state.previous_quantities = copy.deepcopy(self.state.quantities)
+            self.state.previous_points = copy.deepcopy(rewards)
+            self.state.last_message = ""
+            self.set_new_round_of_variant()  # variant specific
+            self.state.splits = {agent_id: None for agent_id in self.agent_ids}
+            self.state.nb_messages_sent = {agent_id: 0 for agent_id in self.agent_ids}
+            is_last_timestep_in_round = True
+            done = self.state.round_nb >= self.nb_of_rounds
+
+        # Message phase
+        elif isinstance(action, Message):
+            self.state.last_message = action.message
+            self.state.nb_messages_sent[current_agent] += 1
+
+            # Move turn to other agent
+            self.state.current_agent = self._other(current_agent)
+
+            # If both agents have reached their message quota, enter split phase
+            if all(
+                self.state.nb_messages_sent[agent_id]
+                >= self.quota_messages_per_agent_per_round
+                for agent_id in self.agent_ids
+            ):
+                self.state.split_phase = True
+            is_last_timestep_in_round = False
+            done = False
+            rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
+            info = {"type": "message"}
+
+        # FIX: previously there was no fallback branch, so a timestep in which
+        # the acting agent produced neither a Message nor a split (e.g. a None
+        # action from a waiting agent) raised NameError on the variables below.
+        else:
+            is_last_timestep_in_round = False
+            done = False
+            rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
+            info = {"type": "noop"}
+
+        info[
+            "is_last_timestep_in_round"
+        ] = is_last_timestep_in_round  # Used later to group round timesteps if needed
+        return done, SimulationStepLog(rewards=rewards, info=info)
+
+    def get_obs(self):
+        """Returns all agent observations in dict"""
+        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}
+
+    @abstractmethod
+    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
+        """Variant hook: map the recorded splits to per-agent rewards."""
+
+    @abstractmethod
+    def get_obs_agent(self, agent_id):
+        """Variant hook: build one agent's observation."""
+
+    def get_state(self):
+        """Return the underlying (mutable) state object."""
+        return self.state
+
+    def get_safe_copy(self):
+        """Return a safe copy of the simulation."""
+        # Shallow-copy the simulation but deep-copy the state so the copy can
+        # be stepped without mutating the original.
+        simulation_copy = copy.copy(self)
+        simulation_copy.state = copy.deepcopy(self.state)
+        return simulation_copy
+
+    @abstractmethod
+    def reset(self) -> dict[AgentId, NegotiationObs]:
+        """Variant hook: start a new episode and return initial observations."""
diff --git a/src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py b/src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccc7d53357033710b8409bdf2bfafafa58a40826
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py
@@ -0,0 +1,244 @@
+from __future__ import annotations
+
+from typing import Callable, Dict, List, Tuple
+
+from mllm.markov_games.negotiation.nego_simulation import Split
+from mllm.markov_games.rollout_tree import SimulationStepLog
+
+
+def avg_reward(sl: SimulationStepLog) -> List[Tuple[str, float]] | None:
+    """Per-agent reward at the current simulation step.
+
+    What it computes:
+        One ``("reward-<agent_id>", value)`` pair per agent present in
+        ``sl.rewards``.
+
+    Rationale / motivation:
+        Monitoring the reward stream at each step helps diagnose reward
+        shaping issues (e.g., unintended negative drift), provides a fairness
+        snapshot, and feeds higher-level summaries (efficiency, surplus
+        allocation, etc.).
+
+    Returns:
+        A list of (metric name, float) tuples. If any agent id contains the
+        substring "buffer" (and not "live") the step is a rollout-buffer
+        artifact and ``None`` is returned to avoid polluting aggregates.
+    """
+    for aid in sl.rewards.keys():
+        if "buffer" in str(aid) and "live" not in str(aid):
+            return None
+    # One value per agent at each step. (The previous implementation also
+    # filtered out None values, which float() can never produce.)
+    return [(f"reward-{aid}", float(v)) for aid, v in (sl.rewards or {}).items()]
+
+
+def split_efficiency(sl: SimulationStepLog) -> List[Tuple[str, float]] | None:
+    """Final-round allocation efficiency relative to a greedy upper bound.
+
+    Only computed on the last timestep of a negotiation round. The upper
+    bound assigns every unit of each item to whichever agent values that item
+    most; the metric is achieved_total_reward / upper_bound, emitted as a
+    single ``("split_efficiency", value)`` pair.
+
+    Motivation:
+        Efficiency (a core welfare notion) distinguishes coordination
+        failures (low efficiency) from distributional disputes (high
+        efficiency but uneven splits).
+
+    Notes / caveats:
+        - The bound is optimistic (not necessarily reachable under protocol
+          constraints) but simple and comparable across runs.
+        - Returns ``None`` when the step is not a round end, required info is
+          missing, a buffer agent is present, or the bound is zero.
+    """
+    info = sl.info or {}
+    if not info or not info.get("is_last_timestep_in_round"):
+        return None
+    quantities = info.get("quantities") or {}
+    values = info.get("values") or {}
+    if not values or not quantities:
+        return None
+    # Exclude rollout-buffer artifacts before doing any work.
+    for aid in sl.rewards.keys():
+        if "buffer" in str(aid) and "live" not in str(aid):
+            return None
+    agent_ids = list(sl.rewards.keys())
+    if isinstance(values[agent_ids[0]], dict):
+        # Per-item valuations: bound = sum over items of qty * best value.
+        item_keys = list(values.values())[0].keys()
+        max_reward = sum(
+            float(quantities[item])
+            * max(float(agent_vals[item]) for agent_vals in values.values())
+            for item in item_keys
+        )
+    else:
+        # Scalar per-agent valuation applies uniformly to every item.
+        # (Previously only the first item's quantity was counted because a
+        # one-element value list was zipped against all quantities.)
+        best_value = max(float(v) for v in values.values())
+        max_reward = best_value * sum(float(q) for q in quantities.values())
+    if max_reward == 0:
+        return None  # degenerate round: avoid division by zero
+    achieved = sum(float(v) for v in sl.rewards.values())
+    # Efficiency is a global metric; emit a single named value.
+    return [("split_efficiency", achieved / max_reward)]
+
+
+def _extract_items_from_split(raw_split: Dict) -> Dict[str, float]:
+    """Return a mapping item -> proposal amount from a split structure.
+
+    Supports generic negotiation splits with nested structure
+    ``{'items_given_to_self': {item: qty, ...}}`` (as a Split instance, a
+    dict, or any object exposing that attribute) as well as TAS coin-only
+    variants which may already be a flat mapping ``{'coins': qty}``.
+    Unrecognized inputs yield an empty dict.
+    """
+    if raw_split is None:
+        return {}
+    if isinstance(raw_split, Split):
+        return {k: float(v) for k, v in raw_split.items_given_to_self.items()}
+    if isinstance(raw_split, dict):
+        nested = raw_split.get("items_given_to_self")
+        if isinstance(nested, dict):
+            return {k: float(v) for k, v in nested.items()}
+        # Fallback: assume already flat mapping of items
+        return {
+            k: float(v) for k, v in raw_split.items() if isinstance(v, (int, float))
+        }
+    if hasattr(raw_split, "items_given_to_self"):
+        # FIX: previously subscripted the object (raw_split["items..."]),
+        # which raises TypeError on attribute-style splits.
+        return {k: float(v) for k, v in raw_split.items_given_to_self.items()}
+    return {}
+
+
+def _average_proposal_relative_value(
+    sl: SimulationStepLog,
+    metric_name: str,
+    comparator: Callable[[float, float], bool],
+    opposite_comparator: Callable[[float, float], bool],
+) -> List[Tuple[str, float]] | None:
+    """Shared implementation for proposal size conditioned on relative value.
+
+    Parameters:
+        comparator: returns True when agent_0's value relation (e.g. < or >)
+            to agent_1 holds for an item and we should collect agent_0's
+            proposed quantity for that item.
+        opposite_comparator: inverse relation used to collect agent_1's items.
+
+    Behavior:
+        - Executes only on the final timestep of a round (where the definitive
+          proposal / allocation is known via ``info['splits']``).
+        - For each item, classifies which agent's value satisfies the chosen
+          relation and records that agent's proposed quantity from the split.
+        - Averages (mean) across all qualifying items per agent.
+        - Returns a list of ``(f"{metric_name}-{agent_id}", mean)`` tuples;
+          agents with no qualifying items are omitted from the list.
+        - Returns ``None`` for non-round-end steps, non-2-agent games, or
+          when a buffer agent is present.
+
+    Why this matters:
+        Distinguishing how much an agent *asks for* when it subjectively
+        values items more (or less) than its counterpart reveals patterns of
+        opportunism vs. concession, even when raw reward differences are
+        subtle.
+    """
+    info = sl.info or {}
+    if not info or not info.get("is_last_timestep_in_round"):
+        return None
+    quantities = info.get("quantities") or {}
+    splits = info.get("splits") or {}
+    values = info.get("values") or {}
+    agent_ids: List[str] = list(sl.rewards.keys())
+    if len(agent_ids) != 2:
+        return None  # Only defined for 2-agent case.
+    for aid in agent_ids:
+        if "buffer" in str(aid) and "live" not in str(aid):
+            return None
+    # Extract per-agent item proposals robustly
+    split_items = {aid: _extract_items_from_split(splits.get(aid)) for aid in agent_ids}
+    agent_0_vals: List[float] = []
+    agent_1_vals: List[float] = []
+    for item in quantities.keys():
+        # Values may be either a float (same for all items) or dict per item
+        v0_raw = values[agent_ids[0]]
+        v1_raw = values[agent_ids[1]]
+        v0 = float(v0_raw[item]) if isinstance(v0_raw, dict) else float(v0_raw)
+        v1 = float(v1_raw[item]) if isinstance(v1_raw, dict) else float(v1_raw)
+        if comparator(v0, v1):
+            agent_0_vals.append(split_items[agent_ids[0]].get(item, 0.0))
+        elif opposite_comparator(v0, v1):
+            agent_1_vals.append(split_items[agent_ids[1]].get(item, 0.0))
+    out: Dict[str, float | None] = {}
+    out[f"{metric_name}-{agent_ids[0]}"] = (
+        sum(agent_0_vals) / len(agent_0_vals) if agent_0_vals else None
+    )
+    out[f"{metric_name}-{agent_ids[1]}"] = (
+        sum(agent_1_vals) / len(agent_1_vals) if agent_1_vals else None
+    )
+
+    # Drop agents with no qualifying items (their mean is undefined).
+    return [(key, value) for key, value in out.items() if value is not None]
+
+
+def average_proposal_when_agent_values_item_lower(
+    sl: SimulationStepLog,
+) -> List[Tuple[str, float | None]] | None:
+    """Mean quantity an agent proposes for items it values *less* than opponent.
+
+    Interpretation:
+        A higher value implies the agent still claims (or is allocated) a
+        notable share of items where it has a comparative *disadvantage* in
+        valuation, signaling either strategic over-claiming or protocol-driven
+        egalitarian splits. Conversely, very low numbers can indicate
+        efficient specialization or excessive concession.
+
+    Returns:
+        A list of ``(metric_name-agent_id, mean)`` tuples; agents with no
+        qualifying items this round are omitted. ``None`` when the step is
+        not a round end, the game has other than 2 agents, or a buffer agent
+        is present.
+    """
+    return _average_proposal_relative_value(
+        sl,
+        "average_proposal_when_agent_values_item_lower",
+        lambda a, b: a < b,
+        lambda a, b: a > b,
+    )
+
+
+def average_proposal_when_agent_values_item_higher(
+    sl: SimulationStepLog,
+) -> List[Tuple[str, float | None]] | None:
+    """Mean quantity an agent proposes for items it values *more* than opponent.
+
+    Interpretation:
+        Captures how aggressively an agent claims items where it holds a
+        comparative *advantage*. Elevated values can reflect rational
+        specialization (efficient exploitation of comparative advantage) or
+        potentially unfair grabs if paired with low concession in the lower
+        valuation metric. Comparing this with the 'lower' counterpart helps
+        profile negotiation style (cooperative vs. exploitative).
+
+    Returns:
+        A list of ``(metric_name-agent_id, mean)`` tuples; agents with no
+        qualifying items this round are omitted. ``None`` when the step is
+        not a round end, the game has other than 2 agents, or a buffer agent
+        is present.
+    """
+    return _average_proposal_relative_value(
+        sl,
+        "average_proposal_when_agent_values_item_higher",
+        lambda a, b: a > b,
+        lambda a, b: a < b,
+    )
+
+
+# Explicit list of metric functions exported for rendering. Helper functions
+# starting with '_' are intentionally excluded. Update this list when adding
+# new public statistics so render.py can rely on it instead of introspecting
+# every callable in the module.
+# Each entry maps a SimulationStepLog to (metric_name, value) pairs, or None
+# when the step should not contribute to that metric.
+stat_functs: list[Callable[[SimulationStepLog], List[Tuple[str, float]]]] = [
+    avg_reward,
+    average_proposal_when_agent_values_item_lower,
+    average_proposal_when_agent_values_item_higher,
+    split_efficiency,
+]
diff --git a/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0a62f2465ac8b34dc09cbc003dcc663b170ffe7
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py
@@ -0,0 +1,94 @@
+from typing import Any, Dict, List, Tuple
+
+from mllm.markov_games.negotiation.nego_agent import (
+ NegotiationAgent,
+ NegotiationAgentState,
+)
+from mllm.markov_games.negotiation.nego_simulation import Split
+from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressObs
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+
+
+class NoPressAgent(NegotiationAgent):
+    """Negotiation agent for the no-communication variant: agents never
+    exchange messages and only submit simultaneous split proposals."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # No communication in this variant
+        self.intro_prompt = (
+            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
+            "Setup:\n"
+            "1. The game consists of multiple independent rounds.\n"
+            "2. In each round, there are multiple items to split between the two agents.\n"
+            "3. Both agents are assigned a per-item value between 1 and 20 (inclusive) in each round.\n"
+            "4. You can observe per-item values of both agents.\n"
+            "5. Because assignments are random, both agents are equally likely to have same expected per-item value.\n"
+            "\n"
+            "Protocol:\n"
+            "1. Both agents simultaneously propose the amount of each item they will keep.\n"
+            "2. If the total sum of proposals is less than or equal to the item quantity, both agents receive their proposed amounts.\n"
+            "3. If the total sum of proposals exceeds the item quantity, they are allocated proportionally.\n"
+            "4. Your points for the round = (amount you receive per item) x (your per-item value for that round), added across all items.\n"
+            "5. Points are accumulated across rounds.\n"
+            "Your goal: {goal}\n"
+        )
+        self.new_round_prompt = (
+            "A New Round Begins\n"
+            "The items to split are {quantities}.\n"
+            "Your per-item values are {value} and {other_agent}'s per-item values are {other_value}."
+        )
+        self.last_round_prompt = (
+            "Last Round Summary:\n"
+            " - Items to split: {last_quantities}\n"
+            " - Your per-item values: {last_value_agent}\n"
+            " - {other_agent}'s per-item values: {last_value_coagent}\n"
+            " - You proposed: {last_split_agent}\n"
+            " - You earned: {last_points_agent} points\n"
+            " - {other_agent} proposed: {last_split_coagent}\n"
+            " - {other_agent} earned: {last_points_coagent} points\n"
+            " - Round Complete.\n"
+        )
+        self.send_split_prompt = "Submit Your Proposal\n" "Respond as {proposal_style}"
+
+    def get_message_regex(self, observation: NoPressObs) -> str:
+        return r"^$"  # No messages allowed
+
+    def get_split_regex(self, observation: NoPressObs) -> str:
+        """Regex matching 'Proposal: <n> <item>, ...' for the round's items."""
+        items = list(observation.quantities.keys())
+        # Accept both singular and plural forms
+        item_pattern = "|".join(
+            [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items]
+        )
+        # FIX: the named groups were malformed ("(?P(...))" / "(?P- ...)"),
+        # which raises re.error on compile; restore (?P<name>...) syntax.
+        regex = rf"(?i)Proposal:\s*((?:\s*(?P<number>10|[0-9])\s*(?P<item>{item_pattern})\s*,?)+)"
+        return regex
+
+    def get_split_action(self, policy_output: str, observation: NoPressObs) -> Split:
+        """Parse the policy output into a Split; unparsed items default to 0."""
+        items = list(observation.quantities.keys())
+        import re as _re
+
+        split_regex = self.get_split_regex(observation)
+        items_given_to_self = {item: 0 for item in items}
+        m = _re.match(split_regex, policy_output.strip())
+        if m:
+            # Find all (number, item) pairs
+            item_pattern = "|".join(
+                [
+                    f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?"
+                    for item in items
+                ]
+            )
+            inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})"
+
+            def normalize_item_name(item_str):
+                # Map a (possibly singular/plural, case-shifted) matched name
+                # back to the canonical item name; None when nothing matches.
+                for orig in items:
+                    if item_str.lower() == orig.lower():
+                        return orig
+                    if orig.endswith("s") and item_str.lower() == orig[:-1].lower():
+                        return orig
+                    if (
+                        not orig.endswith("s")
+                        and item_str.lower() == orig.lower() + "s"
+                    ):
+                        return orig
+                return None
+
+            for num, item in _re.findall(inner_regex, m.group(1)):
+                canonical = normalize_item_name(item)
+                # FIX: skip unmatched names instead of writing a None key.
+                if canonical is not None:
+                    items_given_to_self[canonical] = int(num)
+        return Split(items_given_to_self=items_given_to_self)
diff --git a/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d182187cc72c889a76f2d1c5be4b3afb6b923ed8
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py
@@ -0,0 +1,168 @@
+import copy
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Tuple
+
+from mllm.markov_games.negotiation.nego_simulation import (
+ NegotiationObs,
+ NegotiationSimulation,
+ NegotiationState,
+ Split,
+ compute_tas_style_rewards,
+)
+
+AgentId = str
+
+
@dataclass
class NoPressState(NegotiationState):
    """Simulation state for the no-press (no-communication) negotiation variant.

    Adds nothing to the base ``NegotiationState``; kept as a distinct type so
    the variant has its own state class.
    """

    pass
+
+
@dataclass
class NoPressObs(NegotiationObs):
    """Observation for the no-press variant.

    Extends the base observation with the co-agent's per-item values; these are
    exposed directly (see ``NoPressSimulation.get_obs_agent``) since there is no
    chat through which values could be communicated.
    """

    # Co-agent's per-item values (item name -> value).
    other_value: Dict[str, float]
+
+
class NoPressSimulation(NegotiationSimulation):
    """Negotiation simulation variant with no messaging ("no press").

    Each round, 10 units of every item type are split; per-agent, per-item
    values are resampled every round according to ``game_type``. Agents see
    each other's values directly (no chat phase: the message regex elsewhere
    only accepts the empty string).

    NOTE(review): relies on attributes/helpers provided by the unseen
    ``NegotiationSimulation`` base (``self.rng``, ``self.item_types``,
    ``self.agent_ids``, ``self._other``, ``self._starting_agent_index``,
    ``self.agent_id_to_name``, ``self.quota_messages_per_agent_per_round``) --
    confirm their semantics in the base class.
    """

    def __init__(
        self,
        game_type: Literal["10-1-exclusive", "10-1-ties", "1-to-20"] = "1-to-20",
        same_round_value: bool = True,
        atleast_one_conflict: bool = False,
        *args,
        **kwargs,
    ):
        """Configure value sampling for the variant.

        Args:
            game_type: Value-sampling scheme. ``"10-1-exclusive"``: per item,
                one agent gets 1 and the other 10. ``"10-1-ties"``: each agent
                independently gets 1 or 10. ``"1-to-20"``: uniform integers in
                [1, 20].
            same_round_value: If True, resample until both agents' summed item
                values are equal for the round.
            atleast_one_conflict: If True, resample until at least one item is
                valued differently by the two agents.
        """
        # Set before super().__init__ in case the base constructor triggers
        # sampling hooks that read these attributes.
        self.game_type = game_type
        self.same_round_value = same_round_value
        self.atleast_one_conflict = atleast_one_conflict
        super().__init__(*args, **kwargs)

    def _sample_values(self) -> Dict[AgentId, dict]:
        """Rejection-sample per-agent, per-item values.

        Loops until the ``atleast_one_conflict`` and ``same_round_value``
        constraints (when enabled) are both satisfied.
        """
        values = defaultdict(dict)
        # Before the first reset there is no state yet; fall back to the
        # configured item types.
        if self.state is None:
            item_types = self.item_types
        else:
            item_types = list(self.state.quantities.keys())
        while True:
            for item in item_types:
                if self.game_type == "10-1-exclusive":
                    v = int(self.rng.choice([1, 10]))
                    values[self.agent_ids[0]][item] = v
                    values[self.agent_ids[1]][item] = 10 if v == 1 else 1
                elif self.game_type == "10-1-ties":
                    for aid in self.agent_ids:
                        values[aid][item] = int(self.rng.choice([1, 10]))
                elif self.game_type == "1-to-20":
                    for aid in self.agent_ids:
                        values[aid][item] = int(self.rng.integers(1, 21))
            if self.atleast_one_conflict:
                has_conflict = False
                for item in item_types:
                    agent_values_for_item = [
                        values[aid][item] for aid in self.agent_ids
                    ]
                    if len(set(agent_values_for_item)) > 1:
                        has_conflict = True
                        break
                if not has_conflict:
                    # Every item valued identically by both agents: resample.
                    continue
            # Accept when the agents' total values are equal, or when equality
            # is not required.
            agent_values = [sum(v.values()) for v in values.values()]
            if len(set(agent_values)) == 1 or not self.same_round_value:
                break
        return values

    def _sample_quantities(self) -> Dict[str, int]:
        """Fixed quantities: 10 of each item type (keys lower-cased)."""
        return {item.lower(): 10 for item in self.item_types}

    def set_new_round_of_variant(self):
        """Re-roll quantities/values for a new round; start in the split phase
        (there is no chat phase in this variant)."""
        self.state.quantities = self._sample_quantities()
        self.state.values = self._sample_values()
        self.state.split_phase = True

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Logging payload for this variant (deep-copied to decouple from state)."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "values": copy.deepcopy(state.values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Proportional trust-and-split payoff (see ``compute_tas_style_rewards``)."""
        return compute_tas_style_rewards(
            self.agent_ids, self.state.values, splits, self.state.quantities
        )

    def get_obs(self):
        """Return the observation dict for all agents."""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    def get_obs_agent(self, agent_id):
        """Build ``agent_id``'s observation, including last-round summary fields
        (None on the first round) and the co-agent's current values."""
        other_id = self._other(agent_id)
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self
        obs = NoPressObs(
            round_nb=self.state.round_nb,
            last_message="",  # no chat in this variant
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=self.agent_id_to_name[other_id],
            quantities=self.state.quantities,
            item_types=self.item_types,
            value=self.state.values[agent_id],
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            # Values are public in the no-press variant.
            other_value=self.state.values[other_id],
            last_quantities=self.state.previous_quantities,
        )
        return obs

    def reset(self):
        """Initialize state for a fresh episode and return initial observations."""
        start_agent = self.agent_ids[self._starting_agent_index]
        quantities = self._sample_quantities()
        values = self._sample_values()
        self.state = NoPressState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities=quantities,
            values=values,
            previous_values=None,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            split_phase=True,
            previous_splits=None,
            previous_points=None,
            previous_quantities=None,
        )
        return self.get_obs()
diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..002160873969ab7292f0f62a091e12ec376022c6
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py
@@ -0,0 +1,108 @@
+from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
+from mllm.markov_games.negotiation.nego_simulation import Split
+from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
+
+
class TrustAndSplitAgent(NegotiationAgent):
    """Prompt/parsing logic for the multi-item Trust-and-Split game.

    Supplies the prompt templates consumed by the base ``NegotiationAgent``
    and the regexes used to validate/parse model outputs.
    """

    def __init__(self, num_message_chars, *args, **kwargs):
        """Args:
            num_message_chars: Maximum number of characters allowed per chat
                message.
        """
        self.num_message_chars = num_message_chars
        super().__init__(*args, **kwargs)
        self.intro_prompt = (
            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
            "Setup:\n"
            "1. The game has multiple independent rounds.\n"
            "2. In each round, there are multiple items to split between the two agents.\n"
            "3. Both agents are assigned a per-item value between 1 and 20 (inclusive) in each round.\n"
            "4. You can only observe your own per-item values.\n"
            "5. Because assignments are random, both agents are equally likely to have same expected per-item value.\n"
            "\n"
            "Protocol:\n"
            "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n"
            "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the item.\n"
            " - Use this chat to communicate your private per-item value to make informed proposals.\n"
            "3. After the chat, both agents simultaneously propose the amount of each item they will keep.\n"
            "4. If the total sum of proposals is less than or equal to the item quantity, both agents receive their proposed amounts.\n"
            "5. If the total sum of proposals exceeds the item quantity, they are allocated proportionally.\n"
            "6. Your points for the round = (amount you receive per item) x (your per-item value for that round), added across all items.\n"
            "7. Points are accumulated across rounds.\n"
            "Your goal: {goal}\n"
        )
        self.new_round_prompt = (
            "A New Round Begins\n"
            "The items to split are {quantities}.\n"
            "Your per-item values are {value}."
        )
        self.last_round_prompt = (
            "Last Round Summary:\n"
            " - Items to split: {last_quantities}\n"
            " - Your per-item values: {last_value_agent}\n"
            " - {other_agent}'s per-item values: {last_value_coagent}\n"
            " - You proposed: {last_split_agent}\n"
            " - You earned: {last_points_agent} points\n"
            " - {other_agent} proposed: {last_split_coagent}\n"
            " - {other_agent} earned: {last_points_coagent} points\n"
            " - Round Complete.\n"
        )
        self.send_split_prompt = (
            "Message quota is finished for this round.\n"
            "{other_agent} has finalized their proposal.\n"
            "Submit your finalization now\n"
            "Respond with {proposal_style2}"
        )
        # Intentionally empty: no "wait" filler message is shown.
        self.wait_for_message_prompt = ""
        self.last_message_prompt = "{other_agent} said: {last_message}"
        self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)."

    def get_message_regex(self, observation: TrustAndSplitObs) -> str:
        """Any text up to the configured character budget (newlines included)."""
        return rf"[\s\S]{{0,{self.num_message_chars}}}"

    def get_split_regex(self, observation: TrustAndSplitObs) -> str:
        """Regex matching one or more ``<amount> <item>`` pairs (amounts 0-10).

        Item names are accepted in singular or plural form, case-insensitively.
        """
        items = list(observation.quantities.keys())
        # Accept both singular and plural forms
        item_pattern = "|".join(
            f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items
        )
        # NOTE(review): the original literal was corrupted (string broken across
        # lines and invalid "(?P(" group syntax, apparently from stripped
        # "<...>"-style markup), which was a SyntaxError. Reconstructed with
        # valid named groups; confirm the surrounding delimiters against the
        # intended proposal format.
        return rf"(?i) ?((?:\s*(?P<number>10|[0-9])\s*(?P<item>{item_pattern})\s*,?)+) ?"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitObs
    ) -> Split:
        """Parse the model's proposal text into a :class:`Split`.

        Unmentioned or unrecognized items default to 0.
        """
        import re as _re

        items = list(observation.quantities.keys())
        split_regex = self.get_split_regex(observation)
        items_given_to_self = {item: 0 for item in items}
        m = _re.match(split_regex, policy_output.strip())
        if m:
            # Re-scan the matched body for every (number, item) pair.
            item_pattern = "|".join(
                f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?"
                for item in items
            )
            inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})"

            def _canonical(item_str: str):
                """Map a matched (singular/plural) name back to its original key."""
                lowered = item_str.lower()
                for orig in items:
                    if lowered == orig.lower():
                        return orig
                    if orig.endswith("s") and lowered == orig[:-1].lower():
                        return orig
                    if not orig.endswith("s") and lowered == orig.lower() + "s":
                        return orig
                return None

            for num, raw_item in _re.findall(inner_regex, m.group(1)):
                key = _canonical(raw_item)
                # Guard: the original could insert a None key on a failed match.
                if key is not None:
                    items_given_to_self[key] = int(num)
        return Split(items_given_to_self=items_given_to_self)
diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..e711c2a65d336e4d9b991c68662069e96b4dfee8
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py
@@ -0,0 +1,118 @@
+import copy
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+from mllm.markov_games.agent import Agent
+from mllm.markov_games.negotiation.nego_agent import (
+ Message,
+ NegotiationAgent,
+ NegotiationAgentState,
+ Split,
+)
+from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSObs
+from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
+
+
class TrustAndSplitRPSAgent(NegotiationAgent):
    """Prompt/parsing logic for the rock-paper-scissors Trust-and-Split game.

    Supplies prompt templates for the base ``NegotiationAgent`` plus the
    regexes used to validate chat messages and coin proposals.

    NOTE(review): the literal "<>" delimiters in the prompts and regexes look
    like stripped "<proposal>"/"<message>"-style markup; they are at least
    internally consistent between prompt and regex -- confirm against the
    original templates.
    """

    def __init__(
        self,
        num_message_chars: int,
        message_start_end_format: bool = False,
        proposal_start_end_format: bool = False,
        *args,
        **kwargs,
    ):
        """Args:
            num_message_chars: Maximum characters allowed per chat message.
            message_start_end_format: If True, messages must be wrapped in the
                "<>...<>" delimiters.
            proposal_start_end_format: If True, proposals must be wrapped in
                the "<>...<>" delimiters.
        """
        self.num_message_chars = num_message_chars
        self.message_start_end_format = message_start_end_format
        self.proposal_start_end_format = proposal_start_end_format
        super().__init__(*args, **kwargs)
        self.intro_prompt = (
            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
            "\n"
            "Setup:\n"
            "1. The game has multiple independent rounds.\n"
            "2. In each round, there are 10 coins to split between the two agents.\n"
            "3. Each agent's per-coin value for that round is determined as follows:\n"
            " - Both agents are randomly assigned a rock, paper or scissors hands\n"
            " - Rock has the upper hand over scissors, scissors has the upper hand over paper and paper has the upper hand over rock.\n"
            " - The agent with the upper hand has a per-coin value of 10.\n"
            " - The agent with the lower hand has a per-coin value of 1.\n"
            "4. You only see your own hand, but you may communicate it in messages and infer your value based on the other agent's hand.\n"
            "5. Over many rounds both agents are equally likely to have the upper and lower hand.\n"
            "\n"
            "Protocol:\n"
            "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n"
            "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the 10 coins.\n"
            " - Use this chat to communicate your hand so that both agents can determine their per-coin values.\n"
            "3. After the chat, both agents simultaneously propose how many coins they keep.\n"
            "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n"
            "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n"
            "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n"
            "7. The points are accumulated across rounds.\n"
            "Your goal: {goal}\n"
        )
        self.new_round_prompt = (
            "A New Round Begins\n"
            "Your hand is {hand}. You don't know {other_agent}'s hand yet.\n"
        )
        # self.last_round_prompt = (
        #     "Last Round Summary:\n"
        #     " - Your hand: {last_hand_agent}\n"
        #     " - {other_agent}'s hand: {last_hand_coagent}\n"
        #     " - Your value per coin: {last_value_agent}\n"
        #     " - {other_agent}'s value per coin: {last_value_coagent}\n"
        #     " - You proposed: {last_split_agent} coins\n"
        #     " - You earned: {last_points_agent} points\n"
        #     " - {other_agent} proposed: {last_split_coagent} coins\n"
        #     " - {other_agent} earned: {last_points_coagent} points\n"
        #     " - Round Complete.\n"
        # )
        # Abbreviated summary: only the co-agent's hand strength and proposal.
        self.last_round_prompt = "In the previous round, {other_agent} had a {last_hand_value_coagent} hand and proposed {last_split_coagent} coins.\n"
        if self.proposal_start_end_format:
            self.send_split_prompt = (
                "Submit your proposal\n"
                "Respond with <> x <> where x is an integer in [0, 10]."
            )
        else:
            self.send_split_prompt = (
                "Submit your proposal\n"
                "Respond with x where x is an integer in [0, 10]."
            )
        self.wait_for_message_prompt = "Wait for {other_agent} to send a message..."
        # self.wait_for_message_prompt = ""
        self.last_message_prompt = "{other_agent} said: {last_message}"
        if self.message_start_end_format:
            self.send_message_prompt = f"Send your message now in <>...<> (<={self.num_message_chars} chars)."
        else:
            self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)."

    def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
        """Any text up to the character budget, optionally wrapped in "<>" delimiters."""
        if self.message_start_end_format:
            return (
                rf"<>[\s\S]{{0,{self.num_message_chars}}}<>"
            )
        else:
            return rf"[\s\S]{{0,{self.num_message_chars}}}"

    def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
        """A single integer in [0, 10], optionally wrapped in "<>" delimiters."""
        if self.proposal_start_end_format:
            return r"<> ?(10|[0-9]) ?<>"
        else:
            return r" ?(10|[0-9]) ?"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitRPSObs
    ) -> Split:
        """Parse the proposed coin count into a :class:`Split`.

        NOTE(review): duplicates the patterns from ``get_split_regex``; keep
        the two in sync. If no pattern matches, ``int(policy_output)`` is the
        fallback and will raise ``ValueError`` on non-numeric output.
        """
        import re as _re

        if self.proposal_start_end_format:
            m = _re.search(
                r"<> ?(10|[0-9]) ?<>", policy_output
            )
        else:
            m = _re.search(
                r" ?(10|[0-9]) ?", policy_output
            )
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_rps_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4289f89a0574056024cdd5da0f8a676d331670
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_simulation.py
@@ -0,0 +1,248 @@
+"""
+Trust-and-Split rock-paper-scissors (RPS) simulation.
+
+This environment models a simple bargaining game over 10 coins with messaging.
+Agents are assigned rock/paper/scissors hands, with the winner getting value 10 per coin
+and the loser getting value 1 per coin. Agents alternate sending messages for a fixed
+number of turns per round and then each submits a split proposal indicating how many
+coins they keep for themselves. Rewards are proportional if the proposed totals exceed 10.
+"""
+
+import copy
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Tuple
+
+from numpy.random import default_rng
+
+from mllm.markov_games.negotiation.nego_simulation import (
+ Message,
+ NegotiationObs,
+ NegotiationSimulation,
+ NegotiationState,
+ Split,
+ compute_tas_style_rewards,
+)
+from mllm.markov_games.rollout_tree import SimulationStepLog
+
+AgentId = str
+
+
+def _get_rps_winner(
+ hand1: Literal["rock", "paper", "scissors"],
+ hand2: Literal["rock", "paper", "scissors"],
+) -> Literal["rock", "paper", "scissors"]:
+ """Determine winner of rock-paper-scissors between two hands."""
+ if hand1 == hand2:
+ raise ValueError("Hands should be different")
+ if (
+ (hand1 == "rock" and hand2 == "scissors")
+ or (hand1 == "paper" and hand2 == "rock")
+ or (hand1 == "scissors" and hand2 == "paper")
+ ):
+ return hand1
+ else:
+ return hand2
+
+
@dataclass
class TrustAndSplitRPSState(NegotiationState):
    """State for the RPS Trust-and-Split variant: adds current and previous hands."""

    # Current hand per agent.
    hands: Dict[
        AgentId, Literal["rock", "paper", "scissors"]
    ]  # rock, paper, or scissors
    # Hands of the previous round; None before the first round completes.
    previous_hands: Dict[AgentId, Literal["rock", "paper", "scissors"]] | None
+
+
@dataclass
class TrustAndSplitRPSObs(NegotiationObs):
    """Observation for the RPS variant: own hand plus last-round hand info."""

    # This agent's current hand.
    hand: Literal["rock", "paper", "scissors"]
    # Previous-round hands (None on the first round).
    last_hand_agent: Literal["rock", "paper", "scissors"] | None
    last_hand_coagent: Literal["rock", "paper", "scissors"] | None
    # Whether the co-agent had the winning ("upper") or losing ("lower") hand
    # last round; None on the first round.
    last_hand_value_coagent: Literal["upper", "lower"] | None
+
+
class TrustAndSplitRPSSimulation(NegotiationSimulation):
    """Trust-and-Split over 10 coins where per-coin values come from RPS hands.

    Each round both agents receive distinct rock/paper/scissors hands; the
    winner's per-coin value is 10, the loser's is 1. See the module docstring
    for the full protocol.

    NOTE(review): relies on attributes/helpers from the unseen
    ``NegotiationSimulation`` base (``self.rng``, ``self.agent_ids``,
    ``self._other``, ``self._starting_agent_index``, ``self.item_types``,
    ``self.agent_id_to_name``, ...) -- confirm semantics in the base class.
    """

    def __init__(
        self,
        alternating_hands: bool = False,
        alternating_mix_ratio: float | None = None,
        *args,
        **kwargs,
    ):
        """Args:
            alternating_hands: If True, the previous round's winner is dealt
                the losing hand in the next round (winner alternates).
            alternating_mix_ratio: If set, ``alternating_hands`` is instead
                drawn once per simulation with this probability (overrides the
                explicit flag).
        """
        self.alternating_hands = alternating_hands
        self.alternating_mix_ratio = alternating_mix_ratio
        super().__init__(*args, **kwargs)
        # The draw below needs self.rng, which the base constructor sets up.
        if self.alternating_mix_ratio is not None:
            if self.rng.random() < self.alternating_mix_ratio:
                self.alternating_hands = True
            else:
                self.alternating_hands = False

    def _sample_hands_and_values(
        self,
        alternate_hands: bool = False,
    ) -> Tuple[Dict[AgentId, str], Dict[AgentId, float]]:
        """Deal two distinct hands and derive per-coin values (winner 10, loser 1).

        Args:
            alternate_hands: If True, assign the new losing hand to whichever
                agent won the previous round (requires ``previous_hands``).
        """
        hands = ["rock", "paper", "scissors"]
        if alternate_hands:
            previous_hands = list(self.state.previous_hands.values())
            hand1, hand2 = self.rng.choice(hands, size=2, replace=False)
            winner = _get_rps_winner(hand1, hand2)
            loser = hand1 if winner == hand2 else hand2
            previous_winner = _get_rps_winner(previous_hands[0], previous_hands[1])
            agent_hands, values = {}, {}
            for agent_id in self.agent_ids:
                # Last round's winner gets this round's losing hand.
                if self.state.previous_hands[agent_id] == previous_winner:
                    agent_hands[agent_id] = loser
                    values[agent_id] = 1.0
                else:
                    agent_hands[agent_id] = winner
                    values[agent_id] = 10.0
            return agent_hands, values
        else:
            # Assign different hands to each agent
            hand1, hand2 = self.rng.choice(hands, size=2, replace=False)

            agent_hands = {self.agent_ids[0]: hand1, self.agent_ids[1]: hand2}

            # Determine winner and assign values
            winner = _get_rps_winner(hand1, hand2)
            values = {}
            for agent_id in self.agent_ids:
                if agent_hands[agent_id] == winner:
                    values[agent_id] = 10.0  # Winner gets value 10
                else:
                    values[agent_id] = 1.0  # Loser gets value 1

            return agent_hands, values

    def set_new_round_of_variant(self):
        """Archive current hands, deal new ones, and reopen the chat phase."""
        self.state.previous_hands = copy.deepcopy(self.state.hands)
        new_hands, new_values = self._sample_hands_and_values(
            alternate_hands=self.alternating_hands
        )
        self.state.hands = new_hands
        self.state.values = new_values
        # Quantities are constant in TAS
        self.state.quantities = {"coins": 10}
        self.state.split_phase = False

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Logging payload for this variant (deep-copied to decouple from state)."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "hands": copy.deepcopy(state.hands),
            "values": copy.deepcopy(state.values),
            "previous_hands": copy.deepcopy(state.previous_hands),
            "previous_values": copy.deepcopy(state.previous_values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Proportional trust-and-split payoff (see ``compute_tas_style_rewards``)."""
        return compute_tas_style_rewards(
            self.agent_ids, self.state.values, splits, self.state.quantities
        )

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id"""
        other_id = self._other(agent_id)
        # Previous-round fields default to None before the first round ends.
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        last_hand_coagent = (
            None
            if self.state.previous_hands is None
            else self.state.previous_hands.get(other_id)
        )
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_hand_agent = (
            None
            if self.state.previous_hands is None
            else self.state.previous_hands.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self["coins"]
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[
                "coins"
            ]
        # Summarize whether the co-agent had the winning hand last round.
        if last_hand_agent is None or last_hand_coagent is None:
            last_hand_value_coagent = None
        else:
            winner = _get_rps_winner(last_hand_agent, last_hand_coagent)
            last_hand_value_coagent = (
                "upper" if winner == last_hand_coagent else "lower"
            )
        obs = TrustAndSplitRPSObs(
            round_nb=self.state.round_nb,
            last_message=self.state.last_message,
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=self.agent_id_to_name[other_id],
            quantities={"coins": 10},
            item_types=self.item_types,
            value=self.state.values[agent_id],
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            hand=self.state.hands[agent_id],
            last_hand_coagent=last_hand_coagent,
            last_hand_agent=last_hand_agent,
            last_quantities=self.state.previous_quantities,
            last_hand_value_coagent=last_hand_value_coagent,
        )
        return obs

    def get_state(self):
        """Return the (mutable) internal state object."""
        return self.state

    def get_safe_copy(self):
        """Return a safe copy of the simulation."""
        # Shallow-copy the simulation but deep-copy the state so the copy's
        # state can be mutated independently.
        simulation_copy = copy.copy(self)
        simulation_copy.state = copy.deepcopy(self.state)
        return simulation_copy

    def reset(self):
        """Initialize and return initial observations"""
        # Decide starting agent alternating across resets for determinism
        start_agent = self.agent_ids[self._starting_agent_index]
        hands, values = self._sample_hands_and_values()
        self.state = TrustAndSplitRPSState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities={"coins": 10},
            values=values,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            previous_values=None,
            previous_splits=None,
            previous_points=None,
            split_phase=False,
            hands=hands,
            previous_hands=None,
            previous_quantities=None,
        )
        return self.get_obs()
diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simple_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4439b53a04e8efe4553cb1aa0d85459a6e90c9d
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_agent.py
@@ -0,0 +1,90 @@
+from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
+from mllm.markov_games.negotiation.nego_simulation import Split
+from mllm.markov_games.negotiation.tas_simple_simulation import TrustAndSplitSimpleObs
+
+
class TrustAndSplitSimpleAgent(NegotiationAgent):
    """Prompt/parsing logic for the single-item (coins-only) Trust-and-Split game.

    NOTE(review): the literal "<>" delimiters in the prompts and regexes look
    like stripped "<proposal>"/"<message>"-style markup; they are internally
    consistent between prompt and regex -- confirm against the original
    templates.
    """

    def __init__(
        self,
        num_message_chars,
        message_start_end_format: bool = False,
        proposal_start_end_format: bool = False,
        *args,
        **kwargs,
    ):
        """Args:
            num_message_chars: Maximum characters allowed per chat message.
            message_start_end_format: If True, messages must be wrapped in the
                "<>...<>" delimiters.
            proposal_start_end_format: If True, proposals must be wrapped in
                the "<>...<>" delimiters.
        """
        self.num_message_chars = num_message_chars
        self.message_start_end_format = message_start_end_format
        self.proposal_start_end_format = proposal_start_end_format
        super().__init__(*args, **kwargs)
        self.intro_prompt = (
            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
            "Setup:\n"
            "1. The game has multiple independent rounds.\n"
            "2. In each round, there are 10 coins to split between the two agents.\n"
            "3. Both agents are assigned a per-coin value between 1 and 10 (inclusive) in each round.\n"
            "4. You can only observe your own per-coin value.\n"
            "5. Because assignments are random, both agents are equally likely to have same expected per-coin value.\n"
            "\n"
            "Protocol:\n"
            "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n"
            "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the coins.\n"
            " - Use this chat to communicate your private per-coin value to make informed proposals.\n"
            "3. After the chat, both agents simultaneously propose how many coins they keep.\n"
            "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n"
            "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n"
            "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n"
            "7. Points are accumulated across rounds.\n"
            "Your goal: {goal}\n"
        )
        self.new_round_prompt = (
            "A New Round Begins\n"
            "Your per-coin value is {value}. You don't know {other_agent}'s value yet.\n"
        )
        # Abbreviated summary: only relative value ("higher"/"lower") and proposal.
        self.last_round_prompt = "In the previous round, {other_agent} had a {last_value_str_coagent} value and proposed {last_split_coagent} coins.\n"
        if self.proposal_start_end_format:
            self.send_split_prompt = (
                "Submit your proposal\n"
                "Respond with <> x <> where x is an integer in [0, 10]."
            )
        else:
            self.send_split_prompt = (
                "Submit your proposal\n"
                "Respond with x where x is an integer in [0, 10]."
            )
        self.wait_for_message_prompt = "Wait for {other_agent} to send a message..."
        # self.wait_for_message_prompt = ""
        self.last_message_prompt = "{other_agent} said: {last_message}"
        if self.message_start_end_format:
            self.send_message_prompt = f"Send your message now in <>...<> (<={self.num_message_chars} chars)."
        else:
            self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)."

    def get_message_regex(self, observation: TrustAndSplitSimpleObs) -> str:
        """Any text up to the character budget, optionally wrapped in "<>" delimiters."""
        if self.message_start_end_format:
            return (
                rf"<>[\s\S]{{0,{self.num_message_chars}}}<>"
            )
        else:
            return rf"[\s\S]{{0,{self.num_message_chars}}}"

    def get_split_regex(self, observation: TrustAndSplitSimpleObs) -> str:
        """A single integer in [0, 10], optionally wrapped in "<>" delimiters."""
        if self.proposal_start_end_format:
            return r"<> ?(10|[0-9]) ?<>"
        else:
            return r" ?(10|[0-9]) ?"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitSimpleObs
    ) -> Split:
        """Parse the proposed coin count into a :class:`Split`.

        NOTE(review): duplicates the patterns from ``get_split_regex``; keep
        the two in sync. If no pattern matches, ``int(policy_output)`` is the
        fallback and will raise ``ValueError`` on non-numeric output.
        """
        import re as _re

        if self.proposal_start_end_format:
            m = _re.search(
                r"<> ?(10|[0-9]) ?<>", policy_output
            )
        else:
            m = _re.search(
                r" ?(10|[0-9]) ?", policy_output
            )
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbd0c43d73e3f7b18204b62e71d72b2df1d13e6
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py
@@ -0,0 +1,169 @@
+import copy
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal
+
+from numpy.random import default_rng
+
+from mllm.markov_games.negotiation.nego_simulation import (
+ NegotiationObs,
+ NegotiationSimulation,
+ NegotiationState,
+ Split,
+ compute_tas_style_rewards,
+)
+
+AgentId = str
+
+
@dataclass
class TrustAndSplitSimpleState(NegotiationState):
    """State for the simple (coins-only) Trust-and-Split variant.

    Adds nothing to the base ``NegotiationState``; kept as a distinct type so
    the variant has its own state class.
    """

    pass
+
+
@dataclass
class TrustAndSplitSimpleObs(NegotiationObs):
    """Observation for the simple variant."""

    # "higher"/"lower": co-agent's previous per-coin value relative to ours;
    # None on the first round. Equal values never occur (see _sample_values).
    last_value_str_coagent: str | None
+
+
class TrustAndSplitSimpleSimulation(NegotiationSimulation):
    """Trust-and-Split over 10 coins with scalar per-coin values.

    Unlike the multi-item variants, each agent's value is a single scalar per
    round (not a per-item dict); values are resampled each round and are
    guaranteed to differ between the two agents.

    NOTE(review): relies on attributes/helpers from the unseen
    ``NegotiationSimulation`` base (``self.rng``, ``self.agent_ids``,
    ``self._other``, ``self._starting_agent_index``, ...). Also note the
    scalar values here are fed to ``compute_tas_style_rewards`` just like the
    dict-valued variants -- confirm that helper handles both shapes.
    """

    def __init__(
        self,
        game_type: Literal["10-1-exclusive", "1-to-10"] = "1-to-10",
        dist_type: Literal["uniform", "bimodal"] = "uniform",
        beta_dist_alpha: float = 0.1,
        beta_dist_beta: float = 0.1,
        *args,
        **kwargs,
    ):
        """Args:
            game_type: "10-1-exclusive": one agent gets 1 and the other 10.
                "1-to-10": each agent's value drawn from ``dist_type``.
            dist_type: "uniform": uniform integers in [1, 10]; "bimodal":
                Beta(alpha, beta) scaled to [1, 10] (small alpha/beta pushes
                mass toward the extremes).
            beta_dist_alpha: Alpha parameter of the Beta distribution.
            beta_dist_beta: Beta parameter of the Beta distribution.
        """
        # Set before super().__init__ in case the base constructor triggers
        # sampling hooks that read these attributes.
        self.game_type = game_type
        self.dist_type = dist_type
        self.beta_dist_alpha = beta_dist_alpha
        self.beta_dist_beta = beta_dist_beta
        super().__init__(*args, **kwargs)

    def _sample_values(self) -> Dict[AgentId, dict]:
        """Sample one scalar per-coin value per agent; resample until they differ."""
        values = {}
        while True:
            if self.game_type == "10-1-exclusive":
                v = int(self.rng.choice([1, 10]))
                values[self.agent_ids[0]] = v
                values[self.agent_ids[1]] = 10 if v == 1 else 1
            elif self.game_type == "1-to-10":
                for aid in self.agent_ids:
                    if self.dist_type == "uniform":
                        values[aid] = int(self.rng.integers(1, 11))
                    elif self.dist_type == "bimodal":
                        alpha, beta = self.beta_dist_alpha, self.beta_dist_beta
                        values[aid] = int(round(self.rng.beta(alpha, beta) * 9) + 1)
            # Accept only when the two agents' values differ (ties would make
            # the "higher"/"lower" summary in get_obs_agent undefined).
            if len(set(values.values())) != 1:
                break
        return values

    def _sample_quantities(self) -> Dict[str, int]:
        """Fixed quantities: always 10 coins."""
        return {"coins": 10}

    def set_new_round_of_variant(self):
        """Re-roll quantities/values for a new round; reopen the chat phase."""
        self.state.quantities = self._sample_quantities()
        self.state.values = self._sample_values()
        self.state.split_phase = False

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Logging payload for this variant (deep-copied to decouple from state)."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "values": copy.deepcopy(state.values),
            # "previous_values": copy.deepcopy(state.previous_values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Proportional trust-and-split payoff (see ``compute_tas_style_rewards``)."""
        return compute_tas_style_rewards(
            self.agent_ids, self.state.values, splits, self.state.quantities
        )

    def get_obs(self):
        """Return the observation dict for all agents."""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    def get_obs_agent(self, agent_id):
        """Build ``agent_id``'s observation, including last-round summary fields
        (None on the first round)."""
        other_id = self._other(agent_id)
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self["coins"]
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[
                "coins"
            ]
        # Relative-value summary; equal values are impossible by construction
        # (_sample_values resamples ties), hence the hard error.
        if last_value_agent is None or last_value_coagent is None:
            last_value_str_coagent = None
        else:
            if last_value_coagent > last_value_agent:
                last_value_str_coagent = "higher"
            elif last_value_coagent < last_value_agent:
                last_value_str_coagent = "lower"
            else:
                raise ValueError("Should not be equal values")

        obs = TrustAndSplitSimpleObs(
            round_nb=self.state.round_nb,
            last_message=self.state.last_message,
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=self.agent_id_to_name[other_id],
            quantities=self.state.quantities,
            item_types=self.item_types,
            value=self.state.values[agent_id],
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            last_quantities=self.state.previous_quantities,
            last_value_str_coagent=last_value_str_coagent,
        )
        return obs

    def reset(self):
        """Initialize state for a fresh episode and return initial observations."""
        start_agent = self.agent_ids[self._starting_agent_index]
        quantities = self._sample_quantities()
        values = self._sample_values()
        self.state = TrustAndSplitSimpleState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities=quantities,
            values=values,
            previous_values=None,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            split_phase=False,
            previous_splits=None,
            previous_points=None,
            previous_quantities=None,
        )
        return self.get_obs()
diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5499a146e9da491757a8105965b2d210f8327134
--- /dev/null
+++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py
@@ -0,0 +1,172 @@
+import copy
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal
+
+from numpy.random import default_rng
+
+from mllm.markov_games.negotiation.nego_simulation import (
+ NegotiationObs,
+ NegotiationSimulation,
+ NegotiationState,
+ Split,
+ compute_tas_style_rewards,
+)
+
+AgentId = str
+
+
@dataclass
class TrustAndSplitState(NegotiationState):
    """State for the Trust-and-Split game.

    Structurally identical to ``NegotiationState``; kept as a distinct
    subclass so logs and isinstance checks can tell game variants apart.
    """

    pass
+
+
@dataclass
class TrustAndSplitObs(NegotiationObs):
    """Per-agent observation for the Trust-and-Split game.

    Structurally identical to ``NegotiationObs``; the subclass exists only
    to tag observations with the game variant.
    """

    pass
+
+
class TrustAndSplitSimulation(NegotiationSimulation):
    """Trust-and-Split negotiation: agents chat, then each proposes a split.

    Item values are resampled per round according to ``game_type``; rewards
    are computed with the shared ``compute_tas_style_rewards`` helper.
    """

    def __init__(
        self,
        game_type: Literal["10-1-exclusive", "10-1-ties", "1-to-20"] = "1-to-20",
        same_round_value: bool = True,
        atleast_one_conflict: bool = False,
        *args,
        **kwargs,
    ):
        """Configure value sampling, then defer to the base simulation.

        Args:
            game_type: How per-item values are drawn.
                "10-1-exclusive": one agent values an item 10, the other 1.
                "10-1-ties": both agents draw independently from {1, 10}.
                "1-to-20": both agents draw uniformly from 1..20.
            same_round_value: If True, resample until both agents' total
                valuations are equal (keeps rounds symmetric in stakes).
            atleast_one_conflict: If True, resample until at least one item
                is valued differently by the two agents.
        """
        # Set before super().__init__ — the base constructor presumably
        # triggers reset()/_sample_values(), which read these attributes.
        self.game_type = game_type
        self.same_round_value = same_round_value
        self.atleast_one_conflict = atleast_one_conflict
        super().__init__(*args, **kwargs)

    def _sample_values(self) -> Dict[AgentId, dict]:
        """Rejection-sample per-agent item values satisfying the config.

        Returns:
            Mapping agent_id -> {item: value}.

        NOTE(review): if the constraints are jointly unsatisfiable for a
        given game_type/item count (e.g. "10-1-exclusive" with a single
        item and same_round_value=True, where totals are always 1 vs 10),
        this loop never terminates — confirm configurations used.
        """
        values = defaultdict(dict)
        if self.state is None:
            # First round: no state yet, use the configured item types.
            item_types = self.item_types
        else:
            # Later rounds: keys of the current quantities dict
            # (these may be lowercased — see _sample_quantities).
            item_types = list(self.state.quantities.keys())
        while True:
            # On each retry all keys are overwritten, so reusing `values`
            # across iterations is safe.
            for item in item_types:
                if self.game_type == "10-1-exclusive":
                    # Complementary values: exactly one agent values it at 10.
                    v = int(self.rng.choice([1, 10]))
                    values[self.agent_ids[0]][item] = v
                    values[self.agent_ids[1]][item] = 10 if v == 1 else 1
                elif self.game_type == "10-1-ties":
                    for aid in self.agent_ids:
                        values[aid][item] = int(self.rng.choice([1, 10]))
                elif self.game_type == "1-to-20":
                    for aid in self.agent_ids:
                        # integers(1, 21) draws uniformly from 1..20 inclusive.
                        values[aid][item] = int(self.rng.integers(1, 21))
            # Total valuation per agent, used for the symmetry constraint.
            agent_values = [sum(v.values()) for v in values.values()]
            if self.atleast_one_conflict:
                has_conflict = False
                for item in item_types:
                    agent_values_for_item = [
                        values[aid][item] for aid in self.agent_ids
                    ]
                    if (
                        len(set(agent_values_for_item)) > 1
                    ):  # Different values for this item
                        has_conflict = True
                        break
                if not has_conflict:
                    continue  # resample: no contested item yet
            # Accept when totals are equal (or symmetry is not required).
            if len(set(agent_values)) == 1 or not self.same_round_value:
                break
        return values

    def _sample_quantities(self) -> Dict[str, int]:
        """Return a fixed pool of 10 units per item type.

        NOTE(review): keys are lowercased here, while _sample_values keys the
        first round by self.item_types verbatim — if item_types contain
        uppercase characters the two dicts diverge; verify casing upstream.
        """
        return {item.lower(): 10 for item in self.item_types}

    def set_new_round_of_variant(self):
        """Resample quantities and values and leave the chat (non-split) phase."""
        self.state.quantities = self._sample_quantities()
        self.state.values = self._sample_values()
        self.state.split_phase = False

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Return a deep-copied snapshot of the round for logging/analysis."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "values": copy.deepcopy(state.values),
            # "previous_values": copy.deepcopy(state.previous_values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Score the agreed splits with the shared TAS-style reward rule."""
        return compute_tas_style_rewards(
            self.agent_ids, self.state.values, splits, self.state.quantities
        )

    def get_obs(self):
        """Return the observation for every agent, keyed by agent id."""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    def get_obs_agent(self, agent_id):
        """Build one agent's observation, including last round's outcome.

        All ``last_*`` fields are None on the first round (no previous data).
        """
        other_id = self._other(agent_id)
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            # Each agent sees what both sides kept for themselves last round.
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self
        obs = TrustAndSplitObs(
            round_nb=self.state.round_nb,
            last_message=self.state.last_message,
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=self.agent_id_to_name[other_id],
            quantities=self.state.quantities,
            item_types=self.item_types,
            value=self.state.values[agent_id],  # own (private) valuations only
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            last_quantities=self.state.previous_quantities,
        )
        return obs

    def reset(self):
        """Start a fresh episode and return the initial observations."""
        start_agent = self.agent_ids[self._starting_agent_index]
        quantities = self._sample_quantities()
        values = self._sample_values()
        self.state = TrustAndSplitState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities=quantities,
            values=values,
            previous_values=None,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            split_phase=False,
            previous_splits=None,
            previous_points=None,
            previous_quantities=None,
        )
        return self.get_obs()
diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c77324d067890d56a074a4e427cb5e9b6c59df01
Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a89b2fde9bcdd46de8e225fa9ee9e67ae9f5b9c
Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da0a8fd5d1dceb7ea6f095412c40fd34b8534e9b
Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..650e0c11acbe423288c4a4a4cd005fe2be810eea
Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/models/large_language_model_local.py b/src_code_for_reproducibility/models/large_language_model_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eac1c32c0233cf04106fa12be333ebf74319c2a
--- /dev/null
+++ b/src_code_for_reproducibility/models/large_language_model_local.py
@@ -0,0 +1,384 @@
+"""
+TODO: Figure out how to tweak SGlang not to go OOM when batch size is 32. See https://github.com/sgl-project/sglang/issues/6309.
+"""
+
+import logging
+import os
+import re
+import sys
+import uuid
+from collections.abc import Callable
+from copy import deepcopy
+from datetime import datetime
+from typing import Literal
+
+import httpx
+import requests
+import torch
+import torch.nn as nn
+
+# from sglang.utils import (
+# launch_server_cmd,
+# print_highlight,
+# terminate_process,
+# wait_for_server,
+# )
+from torch.optim import SGD, Adam, AdamW, RMSprop
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import AutoModelForCausalLMWithValueHead
+
+from mllm.chat_utils.apply_template import chat_turns_to_token_ids
+from mllm.markov_games.rollout_tree import ChatTurn
+from mllm.models.adapter_training_wrapper import AdapterWrapper
+from mllm.models.inference_backend import LLMInferenceOutput
+from mllm.models.inference_backend_dummy import DummyInferenceBackend
+from mllm.models.inference_backend_sglang import SGLangOfflineBackend
+from mllm.models.inference_backend_vllm import VLLMAsyncBackend
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.StreamHandler(sys.stdout))
+
+AdapterID = str
+PolicyID = str
+
+
class LeanLocalLLM:
    """One shared base LLM hosting multiple LoRA adapters.

    Combines:
      * a HuggingFace base model (``shared_hf_llm``) used for training,
      * one PEFT/LoRA adapter per policy (``hf_adapters``), and
      * a pluggable inference backend ("vllm", "sglang", or "dummy").

    Also maintains a buffer of frozen "past agent" adapter checkpoints
    (discovered under ``<output_directory>/checkpoints`` and/or supplied
    via ``initial_buffer_paths``) that can be served as opponent policies.
    """

    def __init__(
        self,
        llm_id: str = "base_llm",
        model_name: str = "Qwen/Qwen3-4B-Instruct-2507",
        device: str = "cuda",
        hf_kwargs: dict | None = None,
        adapter_configs: dict | None = None,
        output_directory: str = "./models/",
        inference_backend: Literal["vllm", "sglang", "dummy"] = "vllm",
        inference_backend_sampling_params: dict | None = None,
        inference_backend_init_kwargs: dict | None = None,
        initial_adapter_paths: dict[str, str] | None = None,
        initial_buffer_paths: list[str] | None = None,
        enable_thinking: bool | None = None,
        regex_max_attempts: int = -1,
        max_thinking_characters: int = 0,
    ):
        """Build the shared model, its adapters, and the inference backend.

        Args:
            llm_id: Prefix used for policy ids ("<llm_id>/<adapter_id>").
            model_name: HF Hub id (or local path) of the base model.
            device: Torch device string for the HF adapters.
            hf_kwargs: Extra kwargs for ``AutoModelForCausalLM.from_pretrained``.
            adapter_configs: Mapping adapter_id -> LoRA config.
            output_directory: Root folder for saved adapters and checkpoints.
            inference_backend: Which serving backend to instantiate.
            inference_backend_sampling_params: Sampling params for the backend.
            inference_backend_init_kwargs: Engine init kwargs for the backend.
            initial_adapter_paths: Optional adapter_id -> path/repo of initial
                weights (local or HF Hub).
            initial_buffer_paths: Optional folders of past-agent checkpoints
                to preload into the opponent buffer.
            enable_thinking: Forwarded to the chat template; None means
                "use the template's default".
            regex_max_attempts: -1 enables backend-constrained decoding from
                the start; otherwise the number of free-form retries before
                falling back to constrained decoding (see ``get_action``).
            max_thinking_characters: >0 enables thinking-content extraction.
        """
        # Bug fix: these dict parameters previously used mutable defaults
        # (`= {}`), which Python shares across calls; use None sentinels.
        hf_kwargs = {} if hf_kwargs is None else hf_kwargs
        adapter_configs = {} if adapter_configs is None else adapter_configs
        inference_backend_sampling_params = (
            {}
            if inference_backend_sampling_params is None
            else inference_backend_sampling_params
        )
        inference_backend_init_kwargs = (
            {}
            if inference_backend_init_kwargs is None
            else inference_backend_init_kwargs
        )

        self.inference_backend_name = inference_backend
        self.output_directory = output_directory
        self.llm_id = llm_id
        self.device = torch.device(device) if device else torch.device("cuda")
        self.model_name = model_name
        self.adapter_configs = adapter_configs
        self.adapter_ids = list(adapter_configs.keys())
        self.enable_thinking = enable_thinking
        self.regex_max_attempts = regex_max_attempts
        self.initial_buffer_paths = initial_buffer_paths
        self.max_thinking_characters = max_thinking_characters
        # Running count of invalid-format retries (see get_action).
        self.regex_retries_count = 0

        # Optional user-specified initial adapter weight locations
        # (local or HF Hub). Format: {adapter_id: path_or_repo_id}.
        self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths

        # On-disk layout for this run's adapters.
        self.save_path = str(os.path.join(output_directory, model_name, "adapters"))
        self.adapter_paths = {
            adapter_id: os.path.join(self.save_path, adapter_id)
            for adapter_id in self.adapter_ids
        }

        # Discover frozen "past agent" adapters: first from this run's
        # checkpoints directory, then from user-provided buffer folders.
        checkpoints_dir = os.path.join(self.output_directory, "checkpoints")
        self.past_agent_adapter_paths = {}
        if os.path.isdir(checkpoints_dir):
            for dirname in os.listdir(checkpoints_dir):
                dirpath = os.path.join(checkpoints_dir, dirname)
                if os.path.isdir(dirpath):
                    self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join(
                        dirpath, "agent_adapter"
                    )
        logger.info(
            f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory."
        )
        if self.initial_buffer_paths is not None:
            previous_count = len(self.past_agent_adapter_paths)
            for path in self.initial_buffer_paths:
                if os.path.isdir(path):
                    for dirname in os.listdir(path):
                        dirpath = os.path.join(path, dirname)
                        if os.path.isdir(dirpath):
                            self.past_agent_adapter_paths[
                                f"{dirname}_buffer"
                            ] = os.path.join(dirpath, "agent_adapter")
                else:
                    logger.warning(
                        f"Initial buffer path {path} does not exist or is not a directory."
                    )
            logger.info(
                f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths."
            )
        self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys())

        # Version tags used to detect when adapter weights change.
        self.adapter_train_ids = {
            adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids
        }

        # Tokenizer; pad token aliased to EOS (usual for decoder-only models).
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Dirty flags: True means the backend must reload that adapter.
        self.weights_got_updated: dict[AdapterID, bool] = {
            adapter_id: False
            for adapter_id in (*self.adapter_ids, *self.past_agent_adapter_ids)
        }
        self.current_lora_request = None
        self.currently_loaded_adapter_id = None

        # ------------------------------------------------------------------
        # HF base model + PEFT adapters (training side)
        # ------------------------------------------------------------------
        self.shared_hf_llm = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_name,
            **hf_kwargs,
        )
        self.hf_adapters = {}
        self.optimizers = {}
        for adapter_id in self.adapter_ids:
            # Prefer weights already in the output folder; else fall back to
            # the user-specified initial path; else start from scratch.
            output_path = os.path.join(self.save_path, adapter_id)
            chosen_path: str | None = None
            if os.path.isdir(output_path) and os.listdir(output_path):
                chosen_path = output_path
                logger.info(
                    f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'."
                )
            elif (
                self.initial_adapter_paths and adapter_id in self.initial_adapter_paths
            ):
                chosen_path = self.initial_adapter_paths[adapter_id]
                logger.info(
                    f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'."
                )
            else:
                logger.info(
                    f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch."
                )
            self.hf_adapters[adapter_id] = AdapterWrapper(
                shared_llm=self.shared_hf_llm,
                adapter_id=adapter_id,
                lora_config=adapter_configs[adapter_id],
                path=chosen_path,
            ).to(device)
        # Persist current state of all adapters (ensures remote loads are
        # cached to disk before the inference backend reads them).
        self.export_adapters()

        # ------------------------------------------------------------------
        # Inference backend (serving side)
        # ------------------------------------------------------------------
        if inference_backend == "sglang":
            self.inference_backend = SGLangOfflineBackend(
                model_name=self.model_name,
                save_path=self.save_path,
                adapter_paths=self.adapter_paths,
                tokenizer=self.tokenizer,
                kwargs=inference_backend_init_kwargs,
            )
        elif inference_backend == "vllm":
            self.inference_backend = VLLMAsyncBackend(
                model_name=self.model_name,
                tokenizer=self.tokenizer,
                engine_init_kwargs=inference_backend_init_kwargs,
                sampling_params=inference_backend_sampling_params,
            )
        elif inference_backend == "dummy":
            self.inference_backend = DummyInferenceBackend()
        else:
            raise ValueError(f"Unknown inference_backend: {inference_backend}")

    def reset_regex_retries_count(self) -> None:
        """Reset the cumulative invalid-format retry counter to zero."""
        self.regex_retries_count = 0

    def get_inference_policies(self) -> dict[PolicyID, Callable]:
        """Return async policy callables, one per current and past adapter.

        Each policy hot-swaps its adapter into the inference backend, then
        delegates generation to ``get_action``. Keys are
        "<llm_id>/<adapter_id>".
        """

        def _make_policy(adapter_id: AdapterID) -> Callable:
            # Factory closure: binds adapter_id per policy, avoiding the
            # classic late-binding-in-a-loop bug.
            async def policy(
                state: list[ChatTurn],
                agent_id: str,
                regex: str | None = None,
            ):
                self.prepare_adapter_for_inference(adapter_id=adapter_id)
                return await self.get_action(state, agent_id, regex)

            return policy

        return {
            self.llm_id + "/" + adapter_id: _make_policy(adapter_id)
            for adapter_id in (*self.adapter_ids, *self.past_agent_adapter_ids)
        }

    def get_adapter_modules(self) -> dict[PolicyID, nn.Module]:
        """Return the trainable adapter wrappers, keyed by adapter id.

        The wrappers behave like regular PyTorch modules (see
        adapter_wrapper.py); past-agent buffer adapters are excluded.
        """
        return {an: self.hf_adapters[an] for an in self.adapter_ids}

    async def toggle_training_mode(self) -> None:
        """Switch the backend to training mode and bump adapter version tags."""
        for adn in self.adapter_ids:
            self.adapter_train_ids[adn] = self.short_id_generator()
        await self.inference_backend.toggle_training_mode()

    async def toggle_eval_mode(self) -> None:
        """Switch the inference backend to evaluation mode."""
        await self.inference_backend.toggle_eval_mode()

    def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None:
        """Ensure ``adapter_id``'s weights are loaded in the backend.

        Looks up the adapter path among current adapters first, then the
        past-agent buffer; the dirty flag tells the backend whether a
        reload from disk is required, and is cleared afterwards.
        """
        self.inference_backend.prepare_adapter(
            adapter_id,
            adapter_path=self.adapter_paths.get(
                adapter_id, self.past_agent_adapter_paths.get(adapter_id, None)
            ),
            weights_got_updated=self.weights_got_updated[adapter_id],
        )
        self.currently_loaded_adapter_id = adapter_id
        self.weights_got_updated[adapter_id] = False

    async def get_action(
        self, state: list[ChatTurn], agent_id: str, regex: str | None = None
    ) -> ChatTurn:
        """Generate one assistant turn for ``agent_id`` given chat ``state``.

        Format-enforcement strategy:
          * ``regex_max_attempts == -1``: constrained decoding from the
            first attempt (regex handed to the backend).
          * otherwise: sample free-form and validate with ``re.fullmatch``
            up to ``regex_max_attempts`` times; the final attempt falls
            back to backend-constrained decoding.

        Returns:
            A ``ChatTurn`` with the generated content, token ids and
            log-probs; ``is_state_end`` is always False here.
        """
        current_regex = regex if self.regex_max_attempts == -1 else None
        pattern = re.compile(regex) if regex else None
        nb_attempts = 0
        state = state[:]  # defensive copy: never mutate the caller's list
        while True:
            context_token_ids = chat_turns_to_token_ids(
                chats=state,
                tokenizer=self.tokenizer,
                enable_thinking=self.enable_thinking,
            )
            policy_output = await self.inference_backend.generate(
                input_token_ids=context_token_ids.tolist(),
                extract_thinking=(self.max_thinking_characters > 0),
                regex=current_regex,
            )
            # Accept when there is no regex, the output matches, or retries
            # are exhausted (the last attempt was backend-constrained).
            if (
                pattern is None
                or pattern.fullmatch(policy_output.content)
                or nb_attempts >= self.regex_max_attempts
            ):
                return ChatTurn(
                    agent_id=agent_id,
                    role="assistant",
                    content=policy_output.content,
                    reasoning_content=policy_output.reasoning_content,
                    out_token_ids=policy_output.out_token_ids,
                    log_probs=policy_output.log_probs,
                    is_state_end=False,
                )
            self.regex_retries_count += 1
            nb_attempts += 1
            logger.warning(
                f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}"
            )
            if nb_attempts == self.regex_max_attempts:
                # Last chance: let the backend enforce the format.
                current_regex = regex

    def export_adapters(self) -> None:
        """Persist all adapters to ``save_path`` and mark them dirty.

        PEFT's ``save_pretrained`` on any wrapper saves every attached
        adapter, so a single call on the first adapter suffices.
        NOTE(review): past-agent buffer adapters are also flagged dirty
        even though they are not re-saved here — this forces a backend
        reload on next use; confirm that is intended.
        """
        for adapter_id in self.adapter_ids:
            self.weights_got_updated[adapter_id] = True
        for adapter_id in self.past_agent_adapter_ids:
            self.weights_got_updated[adapter_id] = True
        adapter_id = self.adapter_ids[0]
        self.hf_adapters[adapter_id].save_pretrained(self.save_path)

    def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None:
        """Checkpoint all adapters under ``<output_directory>/checkpoints``.

        The checkpoint folder name combines the first adapter's id, the
        caller-supplied indicator, and a timestamp. Any adapter whose id
        contains "agent" is additionally registered in the past-agent
        buffer as a frozen opponent.
        """
        adapter_id = self.adapter_ids[0]
        output_dir = os.path.join(self.output_directory, "checkpoints")
        os.makedirs(output_dir, exist_ok=True)
        date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}"
        export_path = os.path.join(output_dir, agent_adapter_dir)
        for adapter_id in self.adapter_ids:
            if "agent" in adapter_id:
                self.past_agent_adapter_paths[
                    f"{agent_adapter_dir}_buffer"
                ] = os.path.join(export_path, adapter_id)
                self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer")
                self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False
            self.hf_adapters[adapter_id].save_pretrained(export_path)

    def short_id_generator(self) -> str:
        """Generate a short unique ID for tracking adapter versions.

        Returns:
            str: The first 8 digits of a random UUID's integer form.
        """
        return str(uuid.uuid4().int)[:8]
diff --git a/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05691ab328cc6ccf7a54fc828757fc12582c8ec5
Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10898f798ab2178b950e57100739d421c01f00f7
Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc differ
diff --git a/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09bd9177e48a4140805ea2695cb567ffe987d70a
Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc differ