dereckpichemila committed on
Commit
59b93f6
·
verified ·
1 Parent(s): 4728a06

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src_code_for_reproducibility/__pycache__/__init__.cpython-311.pyc +0 -0
  2. src_code_for_reproducibility/docs/source/conf.py +48 -0
  3. src_code_for_reproducibility/docs/source/environments/diplomacy.rst +459 -0
  4. src_code_for_reproducibility/docs/source/environments/dond.rst +410 -0
  5. src_code_for_reproducibility/docs/source/environments/ipd.rst +411 -0
  6. src_code_for_reproducibility/docs/source/media/runbatch.png +0 -0
  7. src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst +7 -0
  8. src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst +7 -0
  9. src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst +7 -0
  10. src_code_for_reproducibility/docs/source/src.environments.env_imports.rst +7 -0
  11. src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst +7 -0
  12. src_code_for_reproducibility/docs/source/src.environments.ipd.rst +19 -0
  13. src_code_for_reproducibility/docs/source/src.environments.rst +25 -0
  14. src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst +7 -0
  15. src_code_for_reproducibility/docs/source/src.generation.run_games.rst +7 -0
  16. src_code_for_reproducibility/docs/source/src.models.hf_agent.rst +7 -0
  17. src_code_for_reproducibility/docs/source/src.models.vllm_worker_wrap.rst +7 -0
  18. src_code_for_reproducibility/docs/source/src.run.rst +7 -0
  19. src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst +7 -0
  20. src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst +7 -0
  21. src_code_for_reproducibility/docs/source/src.utils.inherit_args.rst +7 -0
  22. src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst +7 -0
  23. src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst +7 -0
  24. src_code_for_reproducibility/docs/source/src.utils.parallel_shuffle.rst +7 -0
  25. src_code_for_reproducibility/docs/source/src.utils.quick_stats.rst +7 -0
  26. src_code_for_reproducibility/docs/source/usage.rst +0 -0
  27. src_code_for_reproducibility/markov_games/__init__.py +0 -0
  28. src_code_for_reproducibility/markov_games/alternative_actions_runner.py +138 -0
  29. src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py +230 -0
  30. src_code_for_reproducibility/markov_games/gather_and_export_utils.py +951 -0
  31. src_code_for_reproducibility/markov_games/mg_utils.py +77 -0
  32. src_code_for_reproducibility/markov_games/simulation.py +87 -0
  33. src_code_for_reproducibility/markov_games/statistics_runner.py +405 -0
  34. src_code_for_reproducibility/models/__init__.py +0 -0
  35. src_code_for_reproducibility/models/__pycache__/__init__.cpython-311.pyc +0 -0
  36. src_code_for_reproducibility/models/__pycache__/adapter_training_wrapper.cpython-311.pyc +0 -0
  37. src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-311.pyc +0 -0
  38. src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-311.pyc +0 -0
  39. src_code_for_reproducibility/models/__pycache__/inference_backend_sglang.cpython-311.pyc +0 -0
  40. src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-311.pyc +0 -0
  41. src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-311.pyc +0 -0
  42. src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-311.pyc +0 -0
  43. src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-311.pyc +0 -0
  44. src_code_for_reproducibility/models/adapter_training_wrapper.py +89 -0
  45. src_code_for_reproducibility/models/inference_backend.py +35 -0
  46. src_code_for_reproducibility/models/inference_backend_dummy.py +53 -0
  47. src_code_for_reproducibility/models/inference_backend_sglang.py +86 -0
  48. src_code_for_reproducibility/models/inference_backend_sglang_local_server.py +127 -0
  49. src_code_for_reproducibility/models/inference_backend_vllm.py +96 -0
  50. src_code_for_reproducibility/models/inference_backend_vllm_local_server.py +160 -0
src_code_for_reproducibility/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (185 Bytes). View file
 
src_code_for_reproducibility/docs/source/conf.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration file for the Sphinx documentation builder.
2
+ import os
3
+ import sys
4
+ sys.path.insert(0, os.path.abspath('../..'))
5
+
6
+ # -- Project information -----------------------------------------------------
7
+ project = 'llm_negotiation'
8
+ copyright = '2023, Your Name'
9
+ author = 'Your Name'
10
+
11
+ # -- General configuration ---------------------------------------------------
12
+ extensions = [
13
+ 'sphinx.ext.autodoc',
14
+ 'sphinx.ext.viewcode',
15
+ 'sphinx.ext.napoleon',
16
+ 'sphinx.ext.autosummary',
17
+ 'sphinx.ext.intersphinx',
18
+ 'sphinx.ext.mathjax',
19
+ 'sphinxcontrib.mermaid',
20
+ 'sphinx_rtd_theme',
21
+ ]
22
+
23
+ templates_path = ['_templates']
24
+ exclude_patterns = []
25
+
26
+ # -- Options for HTML output -------------------------------------------------
27
+ html_theme = 'sphinx_rtd_theme'
28
+ html_static_path = ['_static']
29
+
30
+ # -- Napoleon settings -------------------------------------------------------
31
+ napoleon_google_docstring = True
32
+ napoleon_numpy_docstring = False
33
+ napoleon_include_init_with_doc = True
34
+ napoleon_include_private_with_doc = False
35
+ napoleon_include_special_with_doc = True
36
+ napoleon_use_admonition_for_examples = False
37
+ napoleon_use_admonition_for_notes = False
38
+ napoleon_use_admonition_for_references = False
39
+ napoleon_use_ivar = False
40
+ napoleon_use_param = True
41
+ napoleon_use_rtype = True
42
+ napoleon_preprocess_types = False
43
+ napoleon_type_aliases = None
44
+ napoleon_attr_annotations = True
45
+
46
+ # -- Path setup --------------------------------------------------------------
47
+ # Make sure the project's modules can be found by Sphinx
48
+ sys.path.insert(0, os.path.abspath('../../src'))
src_code_for_reproducibility/docs/source/environments/diplomacy.rst ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Diplomacy
3
+ =================
4
+
5
+ The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
6
+ based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
13
+ and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
14
+ of movement phases, retreat phases, and build phases.
15
+
16
+ Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
17
+ to be used with LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ Game Board and Powers
+ ~~~~~~~~~~~~~~~~~~~~~
23
+
24
+ Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
25
+
26
+ - England (blue)
27
+ - France (light blue)
28
+ - Germany (black)
29
+ - Italy (green)
30
+ - Austria-Hungary (red)
31
+ - Russia (white)
32
+ - Turkey (yellow)
33
+
34
+ Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
35
+
36
+ Units and Movement
+ ~~~~~~~~~~~~~~~~~~
37
+
38
+ There are two types of units in Diplomacy:
39
+ - **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
40
+ - **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
41
+
42
+ During movement phases, each unit can execute one of these orders:
43
+ - **Hold**: The unit remains in its current province (e.g., "A PAR H")
44
+ - Format: [Unit Type] [Province] H
45
+ - Example: "A PAR H" means "Army in Paris holds its position"
46
+
47
+ - **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
48
+ - Format: [Unit Type] [Current Province] - [Destination Province]
49
+ - Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
50
+ - Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
51
+
52
+ - **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
53
+ - Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
54
+ - Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
55
+ - Example: "A PAR S A MAR - BUR" means "Army in Paris supports the move of the Army in Marseilles to Burgundy"
56
+ - Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
57
+
58
+ - **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
59
+ - Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
60
+ - Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
61
+
62
+ All orders are executed simultaneously, and conflicts are resolved based on strength (the unit itself counts as 1, plus 1 for each unit supporting it).
63
+
64
+ Common Province Abbreviations
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
65
+
66
+ Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
67
+ - **PAR**: Paris
68
+ - **LON**: London
69
+ - **BER**: Berlin
70
+ - **MUN**: Munich
71
+ - **BUR**: Burgundy
72
+ - **MAR**: Marseilles
73
+ - **BRE**: Brest
74
+ - **ENG**: English Channel
75
+ - **NTH**: North Sea
76
+ - **VIE**: Vienna
77
+ - **ROM**: Rome
78
+ - **VEN**: Venice
79
+ - **MOW**: Moscow
80
+ - **CON**: Constantinople
81
+
82
+ Example: Movement and Conflicts
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
83
+
84
+ For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
85
+
86
+ Turn Structure
+ ~~~~~~~~~~~~~~
87
+
88
+ A game year consists of five phases:
89
+ 1. **Spring Movement**: All powers submit orders for their units
90
+ 2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
91
+ 3. **Fall Movement**: Another round of movement orders
92
+ 4. **Fall Retreat**: Retreat orders for dislodged units
93
+ 5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
94
+
95
+ Supply Centers and Building
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
96
+
97
+ Supply centers (marked on the map) are key to victory. When a power's unit occupies a supply center at the end of a Fall turn, that power gains control of it. During the Winter Adjustment phase:
98
+ - If you control more supply centers than you have units, you can build new units in your home supply centers
99
+ - If you control fewer supply centers than you have units, you must remove excess units
100
+
101
+ Example: Building and Removing Units
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102
+
103
+ If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
104
+
105
+ Negotiation
+ ~~~~~~~~~~~
106
+
107
+ A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
108
+
109
+ Example: Alliance and Betrayal
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
110
+
111
+ England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
112
+
113
+ Victory Conditions
+ ~~~~~~~~~~~~~~~~~~
114
+
115
+ The game ends when one power controls 18 or more supply centers (a majority of the 34 total centers), or when the players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
116
+
117
+ DiplomacyEnv
118
+ ------------
119
+
120
+ The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
121
+ Negotiation Environment standard.
122
+
123
+ .. code-block:: python
124
+
125
+ class DiplomacyEnv:
126
+ """
127
+ Multi-Agent Negotiation Environment for Diplomacy, adapting Deepmind's implementation
128
+ to the MarlEnvironment standard.
129
+ """
130
+ def __init__(self,
131
+ initial_state: Optional[DiplomacyState] = None,
132
+ max_turns: int = 100,
133
+ points_per_supply_centre: bool = True,
134
+ forced_draw_probability: float = 0.0,
135
+ min_years_forced_draw: int = 35):
136
+ """Initialize the Diplomacy environment.
137
+
138
+ Args:
139
+ initial_state: Initial DiplomacyState (optional)
140
+ max_turns: Maximum number of turns in the game
141
+ points_per_supply_centre: Whether to award points per supply center in case of a draw
142
+ forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
143
+ min_years_forced_draw: Minimum years before considering a forced draw
144
+ """
145
+ # ...
146
+
147
+ def reset(self):
148
+ """Reset the environment to an initial state and return the initial observation.
149
+
150
+ Returns:
151
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
152
+ Each observation contains:
153
+ - board_state: Current state of the board
154
+ - current_season: Current season in the game
155
+ - player_index: Index of the player's power
156
+ - possible_actions: List of possible actions in DeepMind's format
157
+ - human_readable_actions: List of human-readable action descriptions
158
+ - supply_centers: List of supply centers owned by the player
159
+ - units: List of units owned by the player
160
+ - year: Current year in the game
161
+ """
162
+ # ...
163
+
164
+ def step(self, actions):
165
+ """Take a step in the environment using the provided actions.
166
+
167
+ Args:
168
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
169
+ Actions can be:
170
+ - List of integer actions in DeepMind's format
171
+ - List of string actions in text format (e.g., "A MUN - BER")
172
+
173
+ Returns:
174
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
175
+ Each observation has the same structure as in reset().
176
+ done (bool): Whether the episode has ended.
177
+ info (dict): Additional information about the environment, including:
178
+ - turn: Current turn number
179
+ - returns: Game returns if the game is done, otherwise None
180
+ - waiting_for: List of agents that still need to provide actions (if not all actions are provided)
181
+ """
182
+ # ...
183
+
184
+ def get_log_info(self):
185
+ """Get additional information about the environment for logging.
186
+
187
+ Returns:
188
+ log_info (dict): Information about the environment required to log the game, including:
189
+ - power_names: List of power names
190
+ - game_history: History of the game
191
+ - current_turn: Current turn number
192
+ - current_season: Current season name
193
+ - supply_centers: Dictionary mapping power names to supply center counts
194
+ """
195
+ # ...
196
+
197
+ def render(self):
198
+ """Render the current state of the environment.
199
+
200
+ Displays a visualization of the current game state.
201
+ """
202
+ # ...
203
+
204
+ def close(self):
205
+ """Perform any necessary cleanup."""
206
+ # ...
207
+
208
+
209
+ Key Implementation Details
210
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
211
+
212
+ The ``DiplomacyEnv`` class implements several key features:
213
+
214
+ 1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
215
+
216
+ 2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
217
+
218
+ 3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
219
+
220
+ 4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
221
+
222
+ 5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
223
+
224
+ Observation Structure
225
+ ~~~~~~~~~~~~~~~~~~~~~
226
+
227
+ Each agent receives an observation dictionary with the following structure:
228
+
229
+ .. code-block:: python
230
+
231
+ {
232
+ "board_state": np.ndarray, # Board state representation
233
+ "current_season": int, # Season index (0-4)
234
+ "player_index": int, # Index of the player's power (0-6)
235
+ "possible_actions": [int], # List of possible actions in DeepMind's format
236
+ "human_readable_actions": [str], # List of human-readable action descriptions
237
+ "supply_centers": [str], # List of supply centers owned by the player
238
+ "units": [dict], # List of units owned by the player
239
+ "year": int # Current year in the game
240
+ }
241
+
242
+ Action Structure
243
+ ~~~~~~~~~~~~~~~~
244
+
245
+ Actions can be provided in two formats:
246
+
247
+ 1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
248
+
249
+ 2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
250
+
251
+ The environment will convert text actions to the internal format as needed.
252
+
253
+ DiplomacyAgent
254
+ --------------
255
+
256
+ The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
257
+
258
+ .. code-block:: python
259
+
260
+ class DiplomacyAgent:
261
+ """
262
+ Agent handler for Diplomacy, implementing the AgentState interface
263
+ for the multi-agent negotiation standard.
264
+ """
265
+
266
+ def __init__(self,
267
+ power_name: str,
268
+ use_text_interface: bool = True,
269
+ system_prompt: Optional[str] = None):
270
+ """Initialize the Diplomacy agent handler.
271
+
272
+ Args:
273
+ power_name: Name of the power this agent controls
274
+ use_text_interface: Whether to use text-based interface (vs. structured)
275
+ system_prompt: Optional system prompt to use for the LLM
276
+ """
277
+ # ...
278
+
279
+ def step(self, observation_from_env, policy_output=None):
280
+ """Update the agent state based on the observation and action.
281
+
282
+ Args:
283
+ observation_from_env: The observation from the environment, with structure:
284
+ - board_state: Current state of the board
285
+ - current_season: Current season in the game
286
+ - player_index: Index of the player's power
287
+ - possible_actions: List of possible actions
288
+ - human_readable_actions: List of human-readable action descriptions
289
+ - supply_centers: List of supply centers owned by the player
290
+ - units: List of units owned by the player
291
+ - year: Current year in the game
292
+
293
+ policy_output: The output of the policy (LLM response), or None for initial prompt
294
+
295
+ Returns:
296
+ policy_id (str): The policy identifier ("llm_policy")
297
+ policy_input (dict): The input to the policy, with structure:
298
+ - messages: List of conversation messages in the format:
299
+ [{"role": "system", "content": "..."},
300
+ {"role": "user", "content": "..."}]
301
+ action: The official action to be sent to the environment, or None if not ready
302
+ done (bool): Whether the LLM action is ready to be sent to the environment
303
+ info (dict): Additional information about the agent:
304
+ - valid_action: Whether the extracted action is valid
305
+ """
306
+ # ...
307
+
308
+ def get_log_info(self):
309
+ """Get information about the agent required to log a trajectory.
310
+
311
+ Returns:
312
+ log_info (dict): Information about the agent required to log a trajectory:
313
+ - power_name: Name of the power this agent controls
314
+ - conversation_history: List of conversation messages
315
+ - current_action: The current action, if any
316
+ """
317
+ # ...
318
+
319
+ def render(self):
320
+ """Render the current state of the agent.
321
+
322
+ Displays the agent's current state, including conversation history.
323
+ """
324
+ # ...
325
+
326
+ def close(self):
327
+ """Perform any necessary cleanup."""
328
+ # ...
329
+
330
+
331
+ Key Implementation Details
332
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
333
+
334
+ The ``DiplomacyAgent`` class implements several key features:
335
+
336
+ 1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
337
+
338
+ 2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
339
+
340
+ 3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
341
+
342
+ 4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
343
+
344
+ 5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
345
+
346
+ Prompt Structure
347
+ ~~~~~~~~~~~~~~~~
348
+
349
+ The agent generates prompts that include:
350
+
351
+ 1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
352
+
353
+ 2. **Game State Description**: A text description of the current game state, including:
354
+ - Current year and season
355
+ - Supply centers owned
356
+ - Units controlled
357
+ - Possible actions
358
+
359
+ 3. **Action Request**: Instructions on how to format actions.
360
+
361
+ Example system prompt:
362
+
363
+ .. code-block:: text
364
+
365
+ You are playing the role of FRANCE in a game of Diplomacy.
366
+ Your goal is to control as many supply centers as possible.
367
+ You can negotiate with other players and form alliances, but remember that
368
+ these alliances are not binding. When you need to submit orders for your units,
369
+ write them in the correct format, with each order on a new line.
370
+
371
+ Example game state description:
372
+
373
+ .. code-block:: text
374
+
375
+ Year: 1901, Season: SPRING_MOVES
376
+ You are playing as FRANCE.
377
+ You currently control 3 supply centers: PAR, MAR, BRE.
378
+ Your units are: A PAR, A MAR, F BRE.
379
+
380
+ Please provide orders for your units. Here are your possible actions:
381
+ A PAR - BUR
382
+ A PAR - GAS
383
+ A PAR - PIC
384
+ A PAR H
385
+ ...
386
+
387
+ Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
388
+
389
+ Running Diplomacy Games
390
+ -----------------------
391
+
392
+ To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
393
+
394
+ .. code-block:: python
395
+
396
+ from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
397
+ from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
398
+ from mllm.run_matches import run_batched_matches
399
+
400
+ # Create environment and agent handlers
401
+ env = DiplomacyEnv(max_turns=30)
402
+
403
+ agent_handlers = {
404
+ "AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
405
+ "ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
406
+ "FRANCE": DiplomacyAgent(power_name="FRANCE"),
407
+ "GERMANY": DiplomacyAgent(power_name="GERMANY"),
408
+ "ITALY": DiplomacyAgent(power_name="ITALY"),
409
+ "RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
410
+ "TURKEY": DiplomacyAgent(power_name="TURKEY")
411
+ }
412
+
413
+ # Define policy mapping (mapping from policy IDs to actual policy functions)
414
+ policy_mapping = {
415
+ "llm_policy": my_llm_policy_function
416
+ }
417
+
418
+ # Run the game
419
+ game_results = run_batched_matches(
420
+ envs=[env],
421
+ agent_handlers_per_env=[agent_handlers],
422
+ policy_mapping=policy_mapping,
423
+ max_parallel_matches=1
424
+ )
425
+
426
+ # Process results
427
+ for result in game_results:
428
+ print(f"Game finished. Winner: {result['winner']}")
429
+ print(f"Supply centers: {result['supply_centers']}")
430
+
431
+ This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
432
+
433
+ Limitations and Considerations
434
+ ------------------------------
435
+
436
+ 1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
437
+
438
+ 2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
439
+
440
+ 3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
441
+
442
+ 4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
443
+
444
+ 5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
445
+
446
+ Advanced Usage
447
+ --------------
448
+
449
+ For advanced usage, you can customize:
450
+
451
+ 1. **System Prompts**: Modify agent behavior by providing custom system prompts.
452
+
453
+ 2. **Observation Processing**: Extend the observation processing to include additional information.
454
+
455
+ 3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
456
+
457
+ 4. **Visualization**: Add custom visualization methods to the environment's render function.
458
+
459
+ 5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
src_code_for_reproducibility/docs/source/environments/dond.rst ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Deal or No Deal
3
+ =================
4
+
5
+ The Deal or No Deal (DoND) environment provides a multi-agent negotiation interface where players trade
6
+ items with different values. This document describes the API for interacting with the DoND environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ Deal or No Deal is a negotiation game where two agents must agree on how to divide a set of items,
13
+ each of which has different values to each agent. The agents engage in a back-and-forth dialogue to
14
+ determine an allocation of the items, with each trying to maximize their own total value.
15
+
16
+ Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used
17
+ with LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ ### Basic Structure
23
+
24
+ The core mechanics of Deal or No Deal are:
25
+
26
+ 1. Two agents negotiate over a set of items (e.g., books, balls, hats)
27
+ 2. Each item has:
28
+ - A specific quantity (how many of each item is available)
29
+ - A value for each agent (which may differ between agents)
30
+ 3. Agents take turns sending messages to negotiate how to split the items
31
+ 4. Once an agreement is reached, agents finalize the deal
32
+ 5. Points are awarded based on the value of items each agent receives
33
+
34
+ ### Detailed Gameplay
35
+
36
+ #### Setup Phase
37
+
38
+ The game begins with:
39
+ - A set of items (e.g., "book", "hat", "ball")
40
+ - Each item has a quantity (e.g., 6 books, 2 hats, 4 balls)
41
+ - Each agent has private values for each item (e.g., books might be worth 5 points to one agent but only 2 points to the other)
42
+ - Agents are assigned roles (starting negotiator and responding negotiator)
43
+
44
+ #### Negotiation Phase
45
+
46
+ 1. Agents take turns sending free-form text messages to each other
47
+ 2. Messages can include offers, counter-offers, questions, or strategic communication
48
+ 3. There is a maximum number of messages permitted (preventing endless negotiations)
49
+ 4. Either agent can propose to finalize an agreement at any time
50
+
51
+ For example:
52
+ - Agent 1: "I propose I get all the books and you get all the hats and balls."
53
+ - Agent 2: "That doesn't work for me. How about you get 3 books and I get 3 books, all the hats, and all the balls?"
54
+ - Agent 1: "Let me counter-offer: I get 4 books and 2 balls, you get 2 books, all hats, and 2 balls."
55
+
56
+ #### Finalization Phase
57
+
58
+ 1. When an agent wants to finalize a deal, they must specify the exact allocation:
59
+ - How many of each item they receive
60
+ - How many of each item the other agent receives
61
+ 2. The other agent must then either agree (by submitting the same allocation) or reject the finalization
62
+ 3. If both agents submit matching finalizations, the deal is executed
63
+ 4. If finalizations don't match, no agreement is reached, and both agents receive 0 points
64
+
65
+ #### Scoring
66
+
67
+ 1. Each agent's score is calculated based on the value of items they receive
68
+ 2. The formula is: score = Sum over all items i of (quantity of item i the agent receives × value of item i to that agent)
69
+ 3. If no agreement is reached, both agents receive 0 points
70
+
71
+ ### Example Game
72
+
73
+ Let's walk through a simple example:
74
+
75
+ **Setup:**
76
+ - Items: Books (4), Hats (2), Balls (6)
77
+ - Agent 1 values: Books=5, Hats=1, Balls=2
78
+ - Agent 2 values: Books=3, Hats=6, Balls=1
79
+
80
+ **Negotiation (simplified):**
81
+ 1. Agent 1: "I would like all the books and balls. You can have the hats."
82
+ 2. Agent 2: "That doesn't work for me. Books are valuable. I propose I get all the hats and 2 books, you get 2 books and all the balls."
83
+ 3. Agent 1: "How about I get 3 books and all the balls, and you get 1 book and all the hats?"
84
+ 4. Agent 2: "I accept your proposal."
85
+
86
+ **Finalization:**
87
+ - Agent 1 submits: Agent 1 gets (Books: 3, Hats: 0, Balls: 6), Agent 2 gets (Books: 1, Hats: 2, Balls: 0)
88
+ - Agent 2 submits the same allocation, confirming agreement
89
+
90
+ **Scoring:**
91
+ - Agent 1 score: (3 books × 5) + (0 hats × 1) + (6 balls × 2) = 15 + 0 + 12 = 27 points
92
+ - Agent 2 score: (1 book × 3) + (2 hats × 6) + (0 balls × 1) = 3 + 12 + 0 = 15 points
93
+
94
+ ### Game Variations
95
+
96
+ The DoND environment supports several variations through configuration parameters:
97
+
98
+ #### Different Value Distributions
99
+
100
+ The environment offers multiple ways to assign values to items:
101
+
102
+ 1. **Standard Random Setup (dond_random_setup)**:
103
+ - Items have even-numbered quantities
104
+ - Each agent receives distinct random values for each item
105
+ - Values are drawn from a uniform distribution
106
+
107
+ 2. **Independent Random Values (independent_random_vals)**:
108
+ - Item quantities can be any number in the specified range
109
+ - Values for each agent are drawn independently
110
+ - Creates more varied negotiation scenarios
111
+
112
+ 3. **Bicameral Value Distribution (bicameral_vals_assignator)**:
113
+ - Creates a "high value" and "low value" distribution for each item
114
+ - Each agent values approximately half the items highly and half lowly
115
+ - Values are drawn from normal distributions with different means
116
+ - Creates scenarios with clear trade opportunities
117
+
118
+ #### Visibility Options
119
+
120
+ 1. **Finalization Visibility**:
121
+ - When enabled, both agents can see each other's finalization proposals
122
+ - When disabled, finalization proposals remain private until both are submitted
123
+
124
+ 2. **Other Values Visibility**:
125
+ - When enabled, agents can see each other's value functions
126
+ - When disabled, agents only know their own values
127
+ - Creates information asymmetry and richer negotiation dynamics
128
+
129
+ #### Game Modes
130
+
131
+ 1. **Cooperative Mode ("coop")**:
132
+ - Agents are encouraged to find mutually beneficial solutions
133
+ - Success is measured by the sum of both agents' scores
134
+
135
+ 2. **Competitive Mode ("comp")**:
136
+ - Agents aim to maximize their individual scores
137
+ - Creates more adversarial negotiations
138
+
139
+ #### Round Structure
140
+
141
+ 1. **Single Round**:
142
+ - One negotiation session between the same agents
143
+ - Simple evaluation of negotiation skills
144
+
145
+ 2. **Multiple Rounds**:
146
+ - Agents negotiate multiple times with different item setups
147
+ - Allows for learning and adaptation over time
148
+ - Roles can be swapped between rounds
149
+
150
+ DondEnv
151
+ ------------
152
+
153
+ The ``DondEnv`` class provides an interface to the Deal or No Deal environment that follows the Multi-Agent
154
+ Negotiation Environment standard.
155
+
156
+ .. code-block:: python
157
+
158
+ class DondEnv:
159
+ """
160
+ Multi-Agent Negotiation Environment for Deal or No Deal.
161
+ """
162
+ def __init__(
163
+ self,
164
+ agents,
165
+ mode="coop",
166
+ max_messages=None,
167
+ min_messages=None,
168
+ max_chars_per_message=None,
169
+ rounds_per_game=1,
170
+ random_setup_func=None,
171
+ random_setup_kwargs=None,
172
+ role_assignator_func=None,
173
+ role_assignator_func_kwargs=None,
174
+ finalization_visibility=False,
175
+ other_values_visibility=False,
176
+ random_seed=None
177
+ ):
178
+ """Initialize the Deal or No Deal environment.
179
+
180
+ Args:
181
+ agents: List of agent IDs participating in the game
182
+ mode: Game mode ("coop" or "comp")
183
+ max_messages: Maximum number of messages per agent per round
184
+ min_messages: Minimum number of messages per agent per round
185
+ max_chars_per_message: Maximum characters per message
186
+ rounds_per_game: Number of negotiation rounds to play
187
+ random_setup_func: Function to generate item quantities and values
188
+ random_setup_kwargs: Arguments for the random setup function
189
+ role_assignator_func: Function to assign roles to agents
190
+ role_assignator_func_kwargs: Arguments for the role assignator
191
+ finalization_visibility: Whether agents can see each other's finalizations
192
+ other_values_visibility: Whether agents can see each other's values
193
+ random_seed: Seed for reproducibility
194
+ """
195
+ # ...
196
+
197
+ def reset(self):
198
+ """Reset the environment to an initial state and return the initial observation.
199
+
200
+ Returns:
201
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
202
+ """
203
+ # ...
204
+
205
+ def step(self, actions):
206
+ """Take a step in the environment using the provided actions.
207
+
208
+ Args:
209
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
210
+ Actions can be messages or finalization proposals.
211
+
212
+ Returns:
213
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
214
+ done (bool): Whether the episode has ended.
215
+ info (dict): Additional information about the environment.
216
+ """
217
+ # ...
218
+
219
+ def get_state(self):
220
+ """Retrieve the current state of the game.
221
+
222
+ Returns:
223
+ state (dict): The current state of the game, including items, quantities, values, etc.
224
+ """
225
+ # ...
226
+
227
+ Key Implementation Details
228
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+
230
+ The ``DondEnv`` class implements several key features:
231
+
232
+ 1. **Multi-Agent Support**: The environment tracks two agents and manages their alternating messages.
233
+
234
+ 2. **Turn-Based Dialogue**: The environment enforces turn structure and limits on message count.
235
+
236
+ 3. **Finalization Processing**: The environment validates and processes finalization proposals.
237
+
238
+ 4. **Random Setup**: The environment supports multiple methods of generating negotiation scenarios.
239
+
240
+ 5. **Round Management**: The environment can handle multiple rounds with different setups.
241
+
242
+ Observation Structure
243
+ ~~~~~~~~~~~~~~~~~~~~~
244
+
245
+ Each agent receives an observation (state) dictionary with rich information about the game:
246
+
247
+ .. code-block:: python
248
+
249
+ {
250
+ "mode": str, # Game mode ("coop" or "comp")
251
+ "role_values": dict, # Value mappings for each role
252
+ "role_props": dict, # Properties for each role
253
+ "agent_to_role": dict, # Mapping from agent IDs to roles
254
+ "is_new_round": bool, # Whether this is the start of a new round
255
+ "is_new_game": bool, # Whether this is the start of a new game
256
+ "game_over": bool, # Whether the game is over
257
+ "items": list, # List of item names
258
+ "quantities": dict, # Quantities of each item
259
+ "has_finalized": bool, # Whether finalization has been proposed
260
+ "last_message": dict, # The last message sent
261
+ "messages_remaining": dict, # Number of messages each agent can still send
262
+ # And various history tracking fields
263
+ }
264
+
265
+ Action Structure
266
+ ~~~~~~~~~~~~~~~~
267
+
268
+ Actions can be:
269
+
270
+ 1. **Text Messages**: Free-form text for negotiation.
271
+ 2. **Finalization Proposals**: Structured data specifying the exact allocation of items.
272
+
273
+ Example finalization format:
274
+
275
+ .. code-block:: python
276
+
277
+ {
278
+ "type": "finalize",
279
+ "allocation": {
280
+ "agent1": {"book": 3, "hat": 0, "ball": 6},
281
+ "agent2": {"book": 1, "hat": 2, "ball": 0}
282
+ }
283
+ }
284
+
285
+ Value Setup Functions
286
+ ---------------------
287
+
288
+ The DoND environment provides several functions for setting up item values:
289
+
290
+ .. code-block:: python
291
+
292
+ def dond_random_setup(items, min_quant, max_quant, min_val, max_val, random_seed=None):
293
+ """
294
+ Generates items, even-numbered quantities and distinct random values for each category for both agents.
295
+
296
+ Args:
297
+ items (list): List of items.
298
+ min_quant (int): Minimum quantity per item.
299
+ max_quant (int): Maximum quantity per item.
300
+ min_val (int): Minimum value per item.
301
+ max_val (int): Maximum value per item.
302
+ random_seed (int, optional): Seed for random generation.
303
+
304
+ Returns:
305
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
306
+ """
307
+ # ...
308
+
309
+ def independent_random_vals(items, min_quant, max_quant, min_val, max_val, random_seed=None):
310
+ """
311
+ Generates random quantities and independent random values for both agents.
312
+
313
+ Args:
314
+ Similar to dond_random_setup
315
+
316
+ Returns:
317
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
318
+ """
319
+ # ...
320
+
321
+ def bicameral_vals_assignator(items, min_quant, max_quant, low_val_mean, low_val_std, high_val_mean, high_val_std, random_seed=None):
322
+ """
323
+ Generates values with a bicameral distribution - each agent values half the items highly.
324
+
325
+ Args:
326
+ items (list): List of items.
327
+ min_quant, max_quant: Range for quantities
328
+ low_val_mean, low_val_std: Mean and standard deviation for the "low value" distribution
329
+ high_val_mean, high_val_std: Mean and standard deviation for the "high value" distribution
330
+ random_seed: Seed for reproducibility
331
+
332
+ Returns:
333
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
334
+ """
335
+ # ...
336
+
337
+ Running DoND Games
338
+ ----------------------
339
+
340
+ To run Deal or No Deal games with LLM agents, you can use the following structure:
341
+
342
+ .. code-block:: python
343
+
344
+ from mllm.environments.dond.dond_game import DondEnv
345
+ from mllm.environments.dond.dond_agent import DondAgent
346
+ from mllm.run_matches import run_batched_matches
347
+
348
+ # Create environment
349
+ env = DondEnv(
350
+ agents=["agent1", "agent2"],
351
+ mode="coop",
352
+ max_messages=10,
353
+ rounds_per_game=1,
354
+ random_setup_func="dond_random_setup",
355
+ random_setup_kwargs={
356
+ "items": ["book", "hat", "ball"],
357
+ "min_quant": 2,
358
+ "max_quant": 8,
359
+ "min_val": 1,
360
+ "max_val": 10
361
+ },
362
+ finalization_visibility=False
363
+ )
364
+
365
+ # Create agent handlers (implementation details would vary)
366
+ agent_handlers = {
367
+ "agent1": DondAgent(agent_id="agent1"),
368
+ "agent2": DondAgent(agent_id="agent2")
369
+ }
370
+
371
+ # Define policy mapping
372
+ policy_mapping = {
373
+ "llm_policy": my_llm_policy_function
374
+ }
375
+
376
+ # Run the game
377
+ game_results = run_batched_matches(
378
+ envs=[env],
379
+ agent_handlers_per_env=[agent_handlers],
380
+ policy_mapping=policy_mapping,
381
+ max_parallel_matches=1
382
+ )
383
+
384
+ Limitations and Considerations
385
+ ------------------------------
386
+
387
+ 1. **Negotiation Complexity**: The open-ended nature of negotiations can be challenging for some LLM agents.
388
+
389
+ 2. **Parsing Challenges**: Extracting structured finalization proposals from free-form text requires robust parsing.
390
+
391
+ 3. **Optimization Opportunities**: Different agents may employ different negotiation strategies to optimize outcomes.
392
+
393
+ 4. **Fairness Evaluation**: The environment allows research into questions of fair division and Pareto optimality.
394
+
395
+ 5. **Strategic Deception**: Agents might strategically misrepresent their true values, adding complexity to negotiations.
396
+
397
+ Advanced Usage
398
+ --------------
399
+
400
+ For advanced usage, you can:
401
+
402
+ 1. **Custom Value Functions**: Create more complex distributions of item values for specific research questions.
403
+
404
+ 2. **Novel Negotiation Scenarios**: Design item sets and values to test specific negotiation skills.
405
+
406
+ 3. **Curriculum Learning**: Create progressively more difficult negotiation scenarios.
407
+
408
+ 4. **Communication Analysis**: Analyze the language and strategies used in successful negotiations.
409
+
410
+ 5. **Multi-Round Dynamics**: Study how agents adapt their strategies over multiple rounds.
src_code_for_reproducibility/docs/source/environments/ipd.rst ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ===========================
2
+ Iterated Prisoner's Dilemma
3
+ ===========================
4
+
5
+ The Iterated Prisoner's Dilemma environment provides a classic game theory setting for studying cooperation
6
+ and competition between agents. This document describes the API for interacting with the IPD environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ The Prisoner's Dilemma is a fundamental problem in game theory that demonstrates why two rational individuals might not
13
+ cooperate, even when it appears in their best interest to do so. In the iterated version, the same two players
14
+ repeatedly face the same dilemma, allowing for the development of trust or retaliation based on previous interactions.
15
+
16
+ Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used with
17
+ LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ Basic Premise
+ ~~~~~~~~~~~~~
23
+
24
+ The scenario behind the Prisoner's Dilemma is as follows:
25
+
26
+ Two criminals are arrested and imprisoned. Each prisoner is in solitary confinement with no means of communicating with
27
+ the other. The prosecutors lack sufficient evidence to convict the pair on the principal charge, but they have enough
28
+ to convict both on a lesser charge. Simultaneously, the prosecutors offer each prisoner a bargain:
29
+
30
+ - If both prisoners betray each other, each serves 2 years in prison (the "punishment" payoff)
31
+ - If one betrays the other while the other remains silent, the betrayer goes free (the "temptation" payoff) while the
32
+ silent accomplice serves 3 years (the "sucker" payoff)
33
+ - If both remain silent, each serves only 1 year in prison (the "reward" payoff)
34
+
35
+ Game Mechanics
+ ~~~~~~~~~~~~~~
36
+
37
+ In our implementation, the choices are simplified to:
38
+ - **C**: Cooperate (remain silent)
39
+ - **D**: Defect (betray the other prisoner)
40
+
41
+ Each round, both players simultaneously choose either C or D, and receive points based on the combination of their choices:
42
+
43
+ - Both choose C: Both receive the "reward" payoff (3 points by default)
44
+ - Both choose D: Both receive the "punishment" payoff (1 point by default)
45
+ - One chooses C, one chooses D: The defector receives the "temptation" payoff (5 points by default), while the cooperator
46
+ receives the "sucker" payoff (0 points by default)
47
+
48
+ Example: Single Round
+ ~~~~~~~~~~~~~~~~~~~~~
49
+
50
+ Let's see how a single round plays out:
51
+
52
+ 1. Alice and Bob simultaneously make their choices
53
+ 2. If Alice chooses C and Bob chooses C:
54
+ - Alice receives 3 points
55
+ - Bob receives 3 points
56
+ 3. If Alice chooses C and Bob chooses D:
57
+ - Alice receives 0 points
58
+ - Bob receives 5 points
59
+ 4. If Alice chooses D and Bob chooses C:
60
+ - Alice receives 5 points
61
+ - Bob receives 0 points
62
+ 5. If Alice chooses D and Bob chooses D:
63
+ - Alice receives 1 point
64
+ - Bob receives 1 point
65
+
66
+ ### Iterated Game Structure
67
+
68
+ The iterated version repeats this basic game for a fixed number of rounds. The key features are:
69
+
70
+ 1. Players know the total number of rounds in advance
71
+ 2. After each round, players learn what choice the other player made
72
+ 3. Players maintain a cumulative score across all rounds
73
+ 4. Players can adjust their strategy based on the history of previous interactions
74
+
75
+ Game Variations
+ ~~~~~~~~~~~~~~~
76
+
77
+ The IPD environment supports several variations through configuration parameters:
78
+
79
+ #### Different Payoff Matrices
80
+
81
+ The standard payoff values can be modified to create different incentive structures:
82
+ - **Traditional PD**: reward=3, punishment=1, temptation=5, sucker=0
83
+ - **Weak Temptation**: reward=3, punishment=1, temptation=4, sucker=0 (reduces the incentive to defect)
84
+ - **Harsh Punishment**: reward=3, punishment=0, temptation=5, sucker=0 (increases the cost of mutual defection)
85
+ - **Generous**: reward=4, punishment=2, temptation=5, sucker=1 (cushions the blow of being betrayed)
86
+
87
+ #### Game Length Variations
88
+
89
+ The number of rounds can significantly impact strategy:
90
+ - **Short Games** (5-10 rounds): Incentivizes more defection, especially near the end
91
+ - **Medium Games** (20-50 rounds): Allows for the development of tit-for-tat and forgiveness strategies
92
+ - **Long Games** (100+ rounds): Favors steady cooperation with occasional "probing" defections
93
+
94
+ Common Strategies
+ ~~~~~~~~~~~~~~~~~
95
+
96
+ While not enforced by the environment, several well-known strategies can emerge:
97
+ - **Always Cooperate**: Always choose C
98
+ - **Always Defect**: Always choose D
99
+ - **Tit for Tat**: Start with C, then copy what the opponent did in the previous round
100
+ - **Forgiving Tit for Tat**: Like Tit for Tat, but occasionally cooperate even after being defected against
101
+ - **Grudger**: Cooperate until the opponent defects once, then always defect
102
+ - **Random**: Choose randomly between C and D
103
+
104
+ IPDEnv
105
+ ------
106
+
107
+ The ``IPDEnv`` class provides an interface to the Iterated Prisoner's Dilemma environment that follows the
108
+ Multi-Agent Negotiation Environment standard.
109
+
110
+ .. code-block:: python
111
+
112
+ class IPDEnv:
113
+ """
114
+ Iterated Prisoner's Dilemma environment following the MarlEnvironment standard.
115
+
116
+ In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
117
+ The payoffs are as follows:
118
+ - If both cooperate: Both receive the "reward" (usually 3 points)
119
+ - If both defect: Both receive the "punishment" (usually 1 point)
120
+ - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
121
+ and the cooperator receives the "sucker" payoff (usually 0 points)
122
+
123
+ The game is played for a specified number of rounds.
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ rounds_per_game: int = 10,
129
+ reward: float = 3.0, # Both cooperate
130
+ punishment: float = 1.0, # Both defect
131
+ temptation: float = 5.0, # Defector's reward when other cooperates
132
+ sucker: float = 0.0, # Cooperator's reward when other defects
133
+ random_seed: Optional[int] = None,
134
+ ):
135
+ """
136
+ Initialize the Iterated Prisoner's Dilemma environment.
137
+
138
+ Args:
139
+ rounds_per_game: Number of rounds to play
140
+ reward: Payoff when both agents cooperate
141
+ punishment: Payoff when both agents defect
142
+ temptation: Payoff for defecting when other agent cooperates
143
+ sucker: Payoff for cooperating when other agent defects
144
+ random_seed: Random seed for reproducibility
145
+ """
146
+ # ...
147
+
148
+ def reset(self) -> Dict[str, Dict[str, Any]]:
149
+ """
150
+ Reset the environment to an initial state and return the initial observation.
151
+
152
+ Returns:
153
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
154
+ """
155
+ # ...
156
+
157
+ def step(self, actions: Dict[str, str]) -> Tuple[Dict[str, Dict[str, Any]], bool, Dict[str, Any]]:
158
+ """
159
+ Take a step in the environment using the provided actions.
160
+
161
+ Args:
162
+ actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
163
+
164
+ Returns:
165
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
166
+ done (bool): Whether the episode has ended.
167
+ info (dict): Additional information about the environment.
168
+ """
169
+ # ...
170
+
171
+ Key Implementation Details
172
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
173
+
174
+ The ``IPDEnv`` class implements several key features:
175
+
176
+ 1. **Two-Agent Support**: The environment tracks two agents ("alice" and "bob") and manages their interactions.
177
+
178
+ 2. **Round-Based Play**: The environment enforces turn structure and tracks game history.
179
+
180
+ 3. **Payoff Matrix**: The environment calculates rewards based on the standard prisoner's dilemma payoff matrix.
181
+
182
+ 4. **Observation Generation**: The environment generates detailed observations for each agent, including action history and rewards.
183
+
184
+ 5. **Game Termination**: The environment tracks game termination after the specified number of rounds.
185
+
186
+ Observation Structure
187
+ ~~~~~~~~~~~~~~~~~~~~~
188
+
189
+ Each agent receives an observation dictionary with the following structure:
190
+
191
+ .. code-block:: python
192
+
193
+ {
194
+ "current_round": int, # Current round number (0-indexed)
195
+ "rounds_per_game": int, # Total number of rounds in the game
196
+ "history": List[Dict], # Complete game history so far
197
+ "last_round_actions": Dict[str, str], # Actions from the previous round (if any)
198
+ "last_round_reward": float, # Reward received in the previous round (if any)
199
+ "total_reward": float, # Cumulative reward so far
200
+ "payoff_matrix": Dict[str, float], # The game's payoff matrix values
201
+ }
202
+
203
+ Action Structure
204
+ ~~~~~~~~~~~~~~~~
205
+
206
+ Actions are simple strings:
207
+
208
+ 1. ``"C"`` for Cooperate
209
+ 2. ``"D"`` for Defect
210
+
211
+ IPDAgent
212
+ --------------
213
+
214
+ The ``IPDAgent`` class implements the agent handler interface for the Iterated Prisoner's Dilemma, processing observations from the environment and generating actions through an LLM.
215
+
216
+ .. code-block:: python
217
+
218
+ class IPDAgent:
219
+ """
220
+ Agent handler for Iterated Prisoner's Dilemma, implementing the AgentState interface
221
+ for the multi-agent negotiation standard.
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ agent_id: str,
227
+ policy_id: str = "llm_policy",
228
+ system_prompt: Optional[str] = None,
229
+ max_errors: int = 3,
230
+ opponent_id: Optional[str] = None,
231
+ ):
232
+ """
233
+ Initialize the IPD agent handler.
234
+
235
+ Args:
236
+ agent_id: Identifier for this agent ("alice" or "bob")
237
+ policy_id: Identifier for the policy this agent uses
238
+ system_prompt: Optional custom system prompt for the LLM
239
+ max_errors: Maximum number of parsing errors before defaulting to cooperate
240
+ opponent_id: Optional identifier of the opponent (inferred if not provided)
241
+ """
242
+ # ...
243
+
244
+ def step(self, observation_from_env: Dict[str, Any], policy_output: str = None) -> Tuple[str, Dict[str, Any], str, bool, Dict[str, Any]]:
245
+ """
246
+ Update the agent state based on the observation and process the policy output.
247
+
248
+ Args:
249
+ observation_from_env: The observation from the environment
250
+ policy_output: The output from the policy (LLM response)
251
+
252
+ Returns:
253
+ policy_id: The policy identifier
254
+ policy_input: The input to the policy
255
+ action: The action to be sent to the environment
256
+ done: Whether the action is ready to be sent to the environment
257
+ info: Additional information about the agent
258
+ """
259
+ # ...
260
+
261
+ Key Implementation Details
262
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
263
+
264
+ The ``IPDAgent`` class implements several key features:
265
+
266
+ 1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses.
267
+
268
+ 2. **Action Extraction**: The agent parses the LLM's output to extract valid actions (C or D).
269
+
270
+ 3. **Error Handling**: The agent provides helpful error messages when parsing fails and defaults to cooperation after multiple failures.
271
+
272
+ 4. **History Tracking**: The agent maintains and provides the complete game history in its prompts.
273
+
274
+ 5. **Strategy Explanation**: The agent can extract and log the reasoning behind an LLM's decisions.
275
+
276
+ Prompt Structure
277
+ ~~~~~~~~~~~~~~~~
278
+
279
+ The agent generates prompts that include:
280
+
281
+ 1. **System Prompt**: Instructions and context for the LLM, explaining its role and the rules of the Prisoner's Dilemma.
282
+
283
+ 2. **Game State Description**: A text description of the current game state, including:
284
+ - Current round number
285
+ - History of previous rounds (if any)
286
+ - Cumulative score
287
+
288
+ 3. **Action Request**: Instructions on how to format the response, requiring an explicit action tag.
289
+
290
+ Example system prompt:
291
+
292
+ .. code-block:: text
293
+
294
+ You are playing as Alice in an Iterated Prisoner's Dilemma game against Bob.
295
+ In each round, you must choose to either Cooperate (C) or Defect (D).
296
+
297
+ The payoffs are:
298
+ - If both players Cooperate: You each get 3 points
299
+ - If both players Defect: You each get 1 point
300
+ - If you Cooperate and Bob Defects: You get 0 points, Bob gets 5 points
301
+ - If you Defect and Bob Cooperates: You get 5 points, Bob gets 0 points
302
+
303
+ Your goal is to maximize your total points across all rounds.
304
+ The game will last for exactly 10 rounds, and both players know this.
305
+
306
+ Example game state prompt:
307
+
308
+ .. code-block:: text
309
+
310
+ Current round: 3/10
311
+
312
+ History:
313
+ Round 1: You chose C, Bob chose C. You earned 3 points.
314
+ Round 2: You chose C, Bob chose D. You earned 0 points.
315
+
316
+ Your total score so far: 3 points
317
+
318
+ What is your choice for round 3?
319
+ Please respond with <action>C</action> to cooperate or <action>D</action> to defect,
320
+ and explain your reasoning.
321
+
322
+ Running IPD Games
323
+ ----------------------
324
+
325
+ To run Iterated Prisoner's Dilemma games with LLM agents, you can use the following code structure:
326
+
327
+ .. code-block:: python
328
+
329
+ from mllm.environments.ipd.ipd_game import IPDEnv
330
+ from mllm.environments.ipd.ipd_agent import IPDAgent
331
+ from mllm.run_matches import run_batched_matches
332
+
333
+ # Create environment
334
+ env = IPDEnv(
335
+ rounds_per_game=10,
336
+ reward=3.0,
337
+ punishment=1.0,
338
+ temptation=5.0,
339
+ sucker=0.0
340
+ )
341
+
342
+ # Create agent handlers
343
+ agent_handlers = {
344
+ "alice": IPDAgent(agent_id="alice"),
345
+ "bob": IPDAgent(agent_id="bob")
346
+ }
347
+
348
+ # Define policy mapping
349
+ policy_mapping = {
350
+ "llm_policy": my_llm_policy_function
351
+ }
352
+
353
+ # Run the game
354
+ game_results = run_batched_matches(
355
+ envs=[env],
356
+ agent_handlers_per_env=[agent_handlers],
357
+ policy_mapping=policy_mapping,
358
+ max_parallel_matches=1
359
+ )
360
+
361
+ # Process results
362
+ for result in game_results:
363
+ print(f"Game finished. Scores: {result['total_rewards']}")
364
+
365
+ Statistics and Analysis
366
+ -----------------------
367
+
368
+ The IPD environment includes utility functions for analyzing game outcomes:
369
+
370
+ 1. **Cooperation Rates**: Percentage of rounds where each agent cooperated.
371
+ 2. **Mutual Cooperation/Defection**: Percentage of rounds where both agents cooperated, and percentage of rounds where both agents defected.
372
+ 3. **Score Distribution**: Analysis of how points were accumulated over the game.
373
+
374
+ These statistics can be calculated using the ``gather_ipd_statistics`` function:
375
+
376
+ .. code-block:: python
377
+
378
+ from mllm.environments.ipd.ipd_statistics_funcs import gather_ipd_statistics
379
+
380
+ stats = gather_ipd_statistics(match_info, env_info)
381
+ print(f"Cooperation rates: {stats['cooperation_rate']}")
382
+ print(f"Mutual cooperation rate: {stats['mutual_cooperation_rate']}")
383
+ print(f"Mutual defection rate: {stats['mutual_defection_rate']}")
384
+
385
+ Limitations and Considerations
386
+ ------------------------------
387
+
388
+ 1. **Determinism**: Round outcomes are fully determined by the agents' actions; the optional random seed only affects initialization.
389
+
390
+ 2. **Limited Player Count**: The IPD environment only supports exactly two players.
391
+
392
+ 3. **Perfect Information**: Both players have perfect information about the game history.
393
+
394
+ 4. **Simultaneous Actions**: Both players act simultaneously, which requires adaptations for some LLM interfaces.
395
+
396
+ 5. **Fixed Game Length**: The total number of rounds is fixed and known to both players from the start.
397
+
398
+ Advanced Usage
399
+ --------------
400
+
401
+ For advanced usage, you can customize:
402
+
403
+ 1. **Payoff Matrix**: Modify reward values to create different incentive structures.
404
+
405
+ 2. **System Prompts**: Customize the LLM's understanding of the game and potential strategies.
406
+
407
+ 3. **Error Handling**: Adjust how the agent responds to invalid LLM outputs.
408
+
409
+ 4. **Analysis**: Create custom statistics gathering for specific research questions.
410
+
411
+ 5. **Integration**: Connect the IPD environment to other negotiation frameworks or tournament systems.
src_code_for_reproducibility/docs/source/media/runbatch.png ADDED
src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_game module
2
+ =======================================
3
+
4
+ .. automodule:: src.environments.dond.dond_game
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_log\_funcs module
2
+ =============================================
3
+
4
+ .. automodule:: src.environments.dond.dond_log_funcs
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_agent module
2
+ =========================================
3
+
4
+ .. automodule:: src.environments.dond.dond_agent
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.env_imports.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.env\_imports module
2
+ ====================================
3
+
4
+ .. automodule:: src.environments.env_imports
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.ipd.ipd\_log\_funcs module
2
+ ===========================================
3
+
4
+ .. automodule:: src.environments.ipd.ipd_log_funcs
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.ipd.rst ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src.environments.ipd package
2
+ ============================
3
+
4
+ .. automodule:: src.environments.ipd
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+
9
+ Submodules
10
+ ----------
11
+
12
+ .. toctree::
13
+ :maxdepth: 4
14
+
15
+ src.environments.ipd.ipd_agent
16
+ src.environments.ipd.ipd_game
17
+ src.environments.ipd.ipd_log_funcs
18
+ src.environments.ipd.ipd_statistics_funcs
19
+ src.environments.ipd.ipd_training_data_funcs
src_code_for_reproducibility/docs/source/src.environments.rst ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src.environments package
2
+ ========================
3
+
4
+ .. automodule:: src.environments
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+
9
+ Subpackages
10
+ -----------
11
+
12
+ .. toctree::
13
+ :maxdepth: 4
14
+
15
+ src.environments.dond
16
+ src.environments.ipd
17
+
18
+ Submodules
19
+ ----------
20
+
21
+ .. toctree::
22
+ :maxdepth: 4
23
+
24
+ src.environments.env_imports
25
+ src.environments.environment_imports
src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.experiments.dond\_run\_train module
2
+ =======================================
3
+
4
+ .. automodule:: src.experiments.dond_run_train
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.generation.run_games.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.generation.run\_games module
2
+ ================================
3
+
4
+ .. automodule:: src.generation.run_games
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.hf_agent.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.hf\_agent module
2
+ ===========================
3
+
4
+ .. automodule:: src.models.hf_agent
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.vllm_worker_wrap.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.vllm\_worker\_wrap module
2
+ ====================================
3
+
4
+ .. automodule:: src.models.vllm_worker_wrap
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.run.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.run module
2
+ ==============
3
+
4
+ .. automodule:: src.run
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.training.reinforce\_training module
2
+ =======================================
3
+
4
+ .. automodule:: src.training.reinforce_training
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.training.rl\_convs\_processing module
2
+ =========================================
3
+
4
+ .. automodule:: src.training.rl_convs_processing
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.inherit_args.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.inherit\_args module
2
+ ==============================
3
+
4
+ .. automodule:: src.utils.inherit_args
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.log\_gpu\_usage module
2
+ ================================
3
+
4
+ .. automodule:: src.utils.log_gpu_usage
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.log\_statistics module
2
+ ================================
3
+
4
+ .. automodule:: src.utils.log_statistics
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.parallel_shuffle.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.parallel\_shuffle module
2
+ ==================================
3
+
4
+ .. automodule:: src.utils.parallel_shuffle
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.quick_stats.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.quick\_stats module
2
+ =============================
3
+
4
+ .. automodule:: src.utils.quick_stats
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/usage.rst ADDED
File without changes
src_code_for_reproducibility/markov_games/__init__.py ADDED
File without changes
src_code_for_reproducibility/markov_games/alternative_actions_runner.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import copy
3
+ import json
4
+ import os.path
5
+ from typing import Any, Tuple
6
+
7
+ from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame
8
+ from mllm.markov_games.rollout_tree import (
9
+ AgentActLog,
10
+ RolloutTreeBranchNode,
11
+ RolloutTreeNode,
12
+ RolloutTreeRootNode,
13
+ StepLog,
14
+ )
15
+
16
+ AgentId = str
17
+
18
+
19
+
20
async def run_with_unilateral_alt_action(
    markov_game: MarkovGame,
    agent_id: AgentId,
    time_step: int,
    branch_node: RolloutTreeBranchNode,
    max_depth: int,
):
    """
    Roll out one alternative branch for `agent_id` and attach it to `branch_node`.

    `agent_id` gets a freshly generated action via
    `markov_game.set_action_of_agent`, one simulation step is taken, and the
    game is then stepped normally for at most `max_depth` further steps (or
    until it terminates).

    Args:
        markov_game: The game copy to roll out. It is mutated by this call.
        agent_id: The agent whose action is re-generated for the first step.
        time_step: Time step assigned to the first node of the branch;
            subsequent nodes get consecutive time steps.
        branch_node: Node whose `branches` mapping receives the new branch.
        max_depth: Maximum number of steps taken after the first alternative step.

    Returns:
        None. The first node of the branch is appended to
        `branch_node.branches[agent_id]`.
    """

    # Generate alternative action and take a step
    await markov_game.set_action_of_agent(agent_id)
    terminated: bool = markov_game.take_simulation_step()
    first_alternative_node = RolloutTreeNode(
        step_log=markov_game.get_step_log(),
        time_step=time_step,
    )

    # Generate rest of trajectory up to max depth
    time_step += 1
    counter = 1
    previous_node = first_alternative_node
    while not terminated and counter <= max_depth:
        terminated, step_log = await markov_game.step()
        current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
        previous_node.child = current_node
        previous_node = current_node
        counter += 1
        time_step += 1

    # Identity check: `== None` would go through any custom __eq__ on the value.
    if branch_node.branches is None:
        branch_node.branches = {agent_id: [first_alternative_node]}
    else:
        agent_branches = branch_node.branches.get(agent_id, [])
        agent_branches.append(first_alternative_node)
        branch_node.branches[agent_id] = agent_branches
58
+
59
+
60
async def AlternativeActionsRunner(
    markov_game: MarkovGame,
    output_folder: str,
    nb_alternative_actions: int,
    max_depth: int,
    branch_only_on_new_round: bool = False,
):
    """
    Generate a main trajectory with partially completed unilateral branches.

    At every step of the main trajectory, and for every agent, the game state
    from before the step is copied `nb_alternative_actions` times. In each copy
    all *other* agents replay their main-trajectory action while the chosen
    agent generates a new one; the copy is then rolled out by
    `run_with_unilateral_alt_action` for at most `max_depth` further steps.
    The resulting data is used to estimate the updated advantage alignment
    policy gradient terms.

    Args:
        markov_game: The game to run. It is advanced until it terminates.
        output_folder: Not used by this function.
        nb_alternative_actions: Number of alternative branches per agent and
            per main-trajectory step.
        max_depth: Maximum number of steps generated in a branch after its
            first (alternative) step.
        branch_only_on_new_round: Not used by this function.

    Returns:
        The `RolloutTreeRootNode` of the generated tree.
    """

    tasks = []
    time_step = 0
    terminated = False
    root = RolloutTreeRootNode(
        id=markov_game.get_id(),
        crn_id=markov_game.get_crn_id(),
    )
    previous_node = root

    while not terminated:
        # Snapshot taken before any action is applied; branches start from it.
        mg_before_action = markov_game.get_safe_copy()

        # Get safe copies for main branch
        agent_action_safe_copies: dict[AgentId, AgentAndActionSafeCopy] = (
            await markov_game.get_actions_of_agents_without_side_effects()
        )

        markov_game.set_actions_of_agents_manually(agent_action_safe_copies)
        terminated = markov_game.take_simulation_step()
        main_node = RolloutTreeNode(
            step_log=markov_game.get_step_log(), time_step=time_step
        )
        branch_node = RolloutTreeBranchNode(main_child=main_node)
        previous_node.child = branch_node
        previous_node = main_node

        # Get alternative branches by generating new unilateral actions
        for agent_id in markov_game.agent_ids:
            for _ in range(nb_alternative_actions):
                mg_branch: MarkovGame = mg_before_action.get_safe_copy()

                # Every agent except `agent_id` replays its main-branch action.
                # Each branch gets its own copies so branches cannot share state.
                for other_agent_id in mg_branch.agent_ids:
                    if other_agent_id == agent_id:
                        continue
                    main_copy = agent_action_safe_copies[other_agent_id]
                    mg_branch.set_action_and_agent_after_action_manually(
                        agent_id=other_agent_id,
                        agent_action_safe_copy=AgentAndActionSafeCopy(
                            action=copy.deepcopy(main_copy.action),
                            action_info=copy.deepcopy(main_copy.action_info),
                            agent_after_action=main_copy.agent_after_action.get_safe_copy(),
                        ),
                    )

                tasks.append(
                    asyncio.create_task(
                        run_with_unilateral_alt_action(
                            markov_game=mg_branch,
                            time_step=time_step,
                            agent_id=agent_id,
                            branch_node=branch_node,
                            max_depth=max_depth,
                        )
                    )
                )
        time_step += 1

    # wait for all branches to complete
    await asyncio.gather(*tasks)

    return root
src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, Optional, Any
2
+ from diplomacy import Game
3
+ import random
4
+
5
class DiplomacyEnv:
    """Multi-Agent Reinforcement Learning environment for Diplomacy.

    This class wraps the Diplomacy game engine to provide an interface
    compliant with the MARL standard.
    """

    def __init__(self, random_seed=None, map_name="standard", game_id=None, rules=None, max_steps=50):
        """Initialize the Diplomacy environment.

        Args:
            random_seed: Stored on the instance; not otherwise used by this class.
            map_name: The name of the map to use (default: "standard")
            game_id: Optional game ID
            rules: Optional rules to apply to the game
            max_steps: Maximum number of steps before forcing game end (default: 50)
        """
        self.random_seed = random_seed
        self.map_name = map_name
        self.game_id = game_id
        self.rules = rules or []
        self.game = None
        self.active_powers = []
        self.render_mode = None
        self.max_steps = max_steps
        self.current_steps = 0

    def _active_power_names(self):
        """Return the names of all powers of the current game that are not eliminated."""
        return [
            name
            for name, power in self.game.powers.items()
            if not power.is_eliminated()
        ]

    @staticmethod
    def _province_of(unit_or_location):
        """Return the three-letter province of a unit or location string.

        Examples: "A PAR" -> "PAR", "F STP/SC" -> "STP", "STP/NC" -> "STP".
        A leading "*" is stripped as well.
        NOTE(review): the "*" prefix is assumed to mark dislodged units in the
        engine's `get_units` output - confirm against the diplomacy package.
        """
        tokens = unit_or_location.lstrip("*").split()
        return tokens[-1][:3] if tokens else ""

    def reset(self):
        """Reset the environment to an initial state and return the initial observation.

        Returns:
            observation: A dictionary where keys are agent identifiers and values are observations.
        """
        # Initialize a new game
        self.game = Game(game_id=self.game_id, map_name=self.map_name)

        # Apply rules
        for rule in self.rules:
            self.game.add_rule(rule)

        # Determine active powers (not eliminated)
        self.active_powers = self._active_power_names()

        # Reset step counter
        self.current_steps = 0

        # Create initial observations for all powers
        return {
            power_name: self._create_observation(power_name)
            for power_name in self.active_powers
        }

    def step(self, actions):
        """Take a step in the environment using the provided actions.

        `reset()` must have been called first so that a game exists.

        Args:
            actions: A dictionary where keys are agent identifiers and values are
                actions. Each action is a mapping with optional keys "orders"
                (default: []) and "wait" (default: True).

        Returns:
            observations: A dictionary where keys are agent identifiers and values are observations.
            done: Whether the episode has ended.
            info: Additional information about the environment.
        """
        print(f"stepping {self.current_steps}")
        self.current_steps += 1
        # Apply actions (orders) for each power
        for power_name, action in actions.items():
            if power_name not in self.active_powers:
                continue
            orders = action.get("orders", [])
            wait = action.get("wait", True)

            # Set orders for the power
            if orders:
                self.game.set_orders(power_name, orders)

            # Set wait flag
            self.game.set_wait(power_name, wait)

        # Check if all active powers are ready to proceed
        if self.game.does_not_wait():
            # Process the current phase
            self.game.process()

        # Update active powers list after processing
        self.active_powers = self._active_power_names()

        # Create observations for all active powers
        observations = {
            power_name: self._create_observation(power_name)
            for power_name in self.active_powers
        }

        # Check if the game is done (either naturally or due to max steps)
        max_steps_reached = self.current_steps >= self.max_steps
        done = self.game.is_game_done or max_steps_reached

        # Create info dict
        info = {
            "phase": self.game.get_current_phase(),
            "active_powers": self.active_powers,
            "centers": self.game.get_centers(),
            "units": self.game.get_units(),
            "current_steps": self.current_steps,
            "max_steps_reached": max_steps_reached,
        }

        return observations, done, info

    def _create_observation(self, power_name):
        """Create observation for a specific power.

        Args:
            power_name: The name of the power

        Returns:
            An observation dictionary
        """
        return {
            "phase": self.game.get_current_phase(),
            "units": self.game.get_units(),
            "centers": self.game.get_centers(),
            "orderable_locations": self.game.get_orderable_locations(power_name),
            "order_status": self.game.get_order_status(power_name),
            "possible_orders": self._get_possible_orders_for_power(power_name),
        }

    def _get_possible_orders_for_power(self, power_name):
        """Get all possible orders for a power's units.

        Locations are matched by province (first three letters), so a unit on
        a coast such as "F STP/SC" keeps the orders listed under "STP",
        "STP/NC" and "STP/SC".

        Args:
            power_name: The name of the power

        Returns:
            A dictionary mapping locations to their possible orders
        """
        all_possible_orders = self.game.get_all_possible_orders()

        # Provinces where this power has units
        provinces = {
            self._province_of(unit) for unit in self.game.get_units(power_name)
        }

        # For retreat phases, include retreating units
        if self.game.phase_type == 'R':
            power = self.game.get_power(power_name)
            provinces.update(self._province_of(unit) for unit in power.retreats)

        # For adjustment phases, include buildable locations
        elif self.game.phase_type == 'A':
            power = self.game.get_power(power_name)
            # If we have more centers than units, we can build.
            # (With more units than centers, all units are already candidates
            # for removal, so nothing needs to be added.)
            if len(power.centers) > len(power.units):
                # NOTE: relies on a private method of the game engine.
                buildable_sites = self.game._build_sites(power)
                provinces.update(self._province_of(site) for site in buildable_sites)

        # Filter the possible orders to only those for this power's units/locations
        return {
            loc: orders
            for loc, orders in all_possible_orders.items()
            if loc[:3] in provinces
        }

    def get_log_info(self):
        """Get additional information about the environment for logging.

        Returns:
            log_info: Information about the environment required to log the game.
                Empty if no game has been created yet.
        """
        if not self.game:
            return {}

        return {
            "game_id": self.game.game_id,
            "phase": self.game.get_current_phase(),
            "map_name": self.game.map_name,
            "centers": self.game.get_centers(),
            "units": self.game.get_units(),
            "powers": {name: {
                "units": power.units,
                "centers": power.centers,
                "is_eliminated": power.is_eliminated(),
                "order_status": self.game.get_order_status(name)
            } for name, power in self.game.powers.items()},
            "orders": self.game.get_orders(),
            "active_powers": self.active_powers,
            "is_game_done": self.game.is_game_done,
            "outcome": self.game.outcome if self.game.is_game_done else None
        }

    def render(self, mode='human'):
        """Render the current state of the environment.

        Args:
            mode: The rendering mode ('human', 'svg', etc.)

        Returns:
            The rendered image if applicable
        """
        self.render_mode = mode
        if self.game:
            if mode == 'human':
                # Just print basic game state
                print(f"Game: {self.game.game_id}")
                print(f"Phase: {self.game.get_current_phase()}")
                print(f"Active Powers: {self.active_powers}")
                print("Supply Centers:")
                for power_name, centers in self.game.get_centers().items():
                    print(f"  {power_name}: {centers}")
                print("Units:")
                for power_name, units in self.game.get_units().items():
                    print(f"  {power_name}: {units}")
                return None
            elif mode == 'svg':
                # Return SVG representation
                return self.game.render(output_format='svg')
        return None

    def close(self):
        """Perform any necessary cleanup."""
        self.game = None
src_code_for_reproducibility/markov_games/gather_and_export_utils.py ADDED
@@ -0,0 +1,951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import os
5
+ import pickle
6
+ import re
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
11
+
12
+ from mllm.markov_games.rollout_tree import *
13
+
14
+ try:
15
+ # Re-export moved helpers for backward compatibility
16
+ from basic_render import (
17
+ find_iteration_folders,
18
+ gather_rollout_trees,
19
+ get_rollout_trees,
20
+ )
21
+ except Exception:
22
+ pass
23
+
24
+ # --------------------------------------------------------------------------------------
25
+ # Fetch external rollout trees
26
+ # --------------------------------------------------------------------------------------
27
+
28
+
29
def find_iteration_folders(global_folder):
    """Return the sorted iteration_* directories under `global_folder`.

    Directories are looked up directly inside `global_folder` and inside each
    of its seed_* subdirectories; deeper nesting is not searched.
    """
    base = Path(global_folder)

    # iteration_* directly below the global folder
    found = [entry for entry in base.glob("iteration_*") if entry.is_dir()]

    # iteration_* one level down, inside seed_* directories
    for seed_dir in base.glob("seed_*/"):
        if not seed_dir.is_dir():
            continue
        found.extend(
            entry for entry in seed_dir.glob("iteration_*") if entry.is_dir()
        )

    return sorted(found)
49
+
50
+
51
def gather_rollout_trees(iteration_folder):
    """Load every `*.rt.pkl` rollout tree found anywhere below `iteration_folder`.

    Each pickle is validated into a `RolloutTreeRootNode` via `model_validate`.
    NOTE: unpickling can execute arbitrary code; only point this at trusted files.
    """
    trees = []
    for pkl_path in Path(iteration_folder).glob("**/*.rt.pkl"):
        with open(pkl_path, "rb") as handle:
            payload = pickle.load(handle)
        # Validate dicts back into Pydantic model for downstream use
        trees.append(RolloutTreeRootNode.model_validate(payload))
    return trees
62
+
63
+
64
def get_rollout_trees(global_folder) -> list[list[RolloutTreeRootNode]]:
    """Return one list of rollout trees per iteration folder found under `global_folder`.

    The outer list follows the order returned by `find_iteration_folders`.
    """
    return [
        gather_rollout_trees(folder)
        for folder in find_iteration_folders(global_folder)
    ]
71
+
72
+
73
+ # --------------------------------------------------------------------------------------
74
+ # Gather data from rollout tree methods
75
+ # --------------------------------------------------------------------------------------
76
+
77
+
78
def load_rollout_tree(path: Path) -> RolloutTreeRootNode:
    """Unpickle the file at `path` and validate its content into a `RolloutTreeRootNode`.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    return RolloutTreeRootNode.model_validate(payload)
83
+
84
+
85
@dataclass
class RolloutNodeList:
    """A flat, ordered list of rollout nodes forming one path, plus an identifier."""

    # Path identifier string, e.g. "mgid:<id>_type:main" as built by get_rollout_tree_paths.
    id: str
    # Nodes of the path, in traversal order.
    nodes: List[RolloutTreeNode]
89
+
90
+
91
def get_rollout_tree_paths(
    root: RolloutTreeRootNode, mgid: Optional[str] = None
) -> Tuple[RolloutNodeList, List[RolloutNodeList]]:
    """
    Flatten a rollout tree into its main path and its branch paths.

    Args:
        root: Root of the rollout tree.
        mgid: Identifier used in the generated path ids; falls back to `str(root.id)`.

    Returns:
        main_path: The main path from the root to the end of the tree.
        branch_paths: A list of all branch paths from the root to the end of the tree.
        Each branch path contains a list of nodes that are part of the branch, including the nodes from the main path before the branch was taken.
    """
    branch_paths = []

    def collect_path_nodes(current) -> List[RolloutTreeNode]:
        """Recursively collect all nodes in a path starting from current node."""
        if current is None:
            return []

        if isinstance(current, RolloutTreeNode):
            return [current] + collect_path_nodes(current.child)

        elif isinstance(current, RolloutTreeBranchNode):
            # For branch nodes, we only follow the main_child for path collection
            if current.main_child:
                return [current.main_child] + collect_path_nodes(
                    current.main_child.child
                )
            else:
                return []

    def traverse_for_branches(
        current,
        main_path_prefix: List[RolloutTreeNode],
        path_id: str,
        current_time_step: Optional[int] = 0,
    ):
        """Traverse tree to collect all branch paths.

        `path_id` is only forwarded to recursive calls and never read.
        `current_time_step` is the time step of the main-path node visited just
        before `current` (0 on the initial call).
        """
        if current is None:
            return

        if isinstance(current, RolloutTreeNode):
            # Continue traversing with this node added to the main path prefix
            new_prefix = main_path_prefix + [current]
            traverse_for_branches(current.child, new_prefix, path_id, current.time_step)

        elif isinstance(current, RolloutTreeBranchNode):
            # Collect all branch paths
            if current.branches:
                for agent_id, branch_node_list in current.branches.items():
                    if branch_node_list:
                        # Start with the main path prefix, then recursively collect all nodes in this branch
                        branch_path_nodes = main_path_prefix.copy()
                        # NOTE(review): all alternatives of one agent at this node are
                        # concatenated into a single path - confirm this is intended
                        # when more than one alternative per agent is generated.
                        for branch_node in branch_node_list:
                            branch_path_nodes.extend(collect_path_nodes(branch_node))

                        # Create proper branch path ID with mgid, agent_id, and time_step
                        # NOTE(review): the label is the preceding node's time step, and
                        # the initial call defaults to 0, so the first two branch nodes
                        # may produce identical ids - verify against consumers of the id.
                        mgid_str = mgid or str(root.id)
                        branch_path_id = f"mgid:{mgid_str}_type:branch_agent:{agent_id}_time_step:{current_time_step}"
                        branch_paths.append(
                            RolloutNodeList(id=branch_path_id, nodes=branch_path_nodes)
                        )

            # Process the main child and add to prefix
            new_prefix = main_path_prefix
            if current.main_child:
                new_prefix = main_path_prefix + [current.main_child]

            # Continue traversing the main path
            if current.main_child:
                traverse_for_branches(
                    current.main_child.child,
                    new_prefix,
                    path_id,
                    current.main_child.time_step,
                )

    # Collect the main path nodes
    main_path_nodes = collect_path_nodes(root.child)

    # Traverse to collect all branch paths
    traverse_for_branches(root.child, [], "")

    # Create the main path with proper mgid format
    mgid_str = mgid or str(root.id)
    main_path = RolloutNodeList(id=f"mgid:{mgid_str}_type:main", nodes=main_path_nodes)

    return main_path, branch_paths
176
+
177
+
178
class ChatTurnLog(BaseModel):
    """One chat turn of one agent, flattened together with its node's time step and reward."""

    # Time step of the rollout node the turn was taken from.
    time_step: int
    # Agent the turn belongs to.
    agent_id: str
    # Role of the turn; the gathering code distinguishes "user" from everything else.
    role: str
    content: str
    # Filled from the turn's `reasoning_content` attribute when it has one.
    reasoning: Optional[str] = None
    is_state_end: bool
    # The agent's reward in the node's simulation step log (0 when the agent has no entry).
    reward: float
186
+
187
+
188
def gather_agent_chat_turns_for_path(
    agent_id: str, path: RolloutNodeList
) -> List[ChatTurnLog]:
    """Collect the chat turns of `agent_id` along `path`, in node order.

    Nodes without an action log for the agent contribute nothing. Every turn
    carries the time step of its node and the agent's reward at that node
    (0 when the agent has no reward entry).
    """
    collected = []
    for node in path.nodes:
        agent_log = node.step_log.action_logs.get(agent_id, [])
        if not agent_log:
            continue
        collected.extend(
            ChatTurnLog(
                time_step=node.time_step,
                agent_id=agent_id,
                role=turn.role,
                content=turn.content,
                reasoning=getattr(turn, "reasoning_content", None),
                is_state_end=turn.is_state_end,
                reward=node.step_log.simulation_step_log.rewards.get(agent_id, 0),
            )
            for turn in (agent_log.chat_turns or [])
        )
    return collected
211
+
212
+
213
def gather_all_chat_turns_for_path(path: RolloutNodeList) -> List[ChatTurnLog]:
    """Collect the chat turns of all agents along `path`, interleaved per node.

    Within each node, every agent's turns are grouped into units: a "user"
    turn followed by the next non-user turn forms a pair, and any turn that
    cannot be paired stands alone. Units are then emitted round-robin over the
    agents in sorted id order: A1, B1, A2, B2, ...
    """
    ordered: List[ChatTurnLog] = []

    for node in path.nodes:
        agent_ids = sorted(list(node.step_log.action_logs.keys()))
        units_by_agent: Dict[str, List[List[ChatTurnLog]]] = {}

        for agent_id in agent_ids:
            log = node.step_log.action_logs.get(agent_id)
            units: List[List[ChatTurnLog]] = []
            pending: List[ChatTurnLog] = []  # a user turn still waiting for its reply

            for chat_turn in (log.chat_turns if log and log.chat_turns else []):
                entry = ChatTurnLog(
                    time_step=node.time_step,
                    agent_id=agent_id,
                    role=chat_turn.role,
                    content=chat_turn.content,
                    reasoning=getattr(chat_turn, "reasoning_content", None),
                    is_state_end=chat_turn.is_state_end,
                    reward=node.step_log.simulation_step_log.rewards.get(agent_id, 0),
                )

                if chat_turn.role == "user":
                    # A new user turn closes any user turn that is still open
                    if pending:
                        units.append(pending)
                    pending = [entry]
                elif pending and len(pending) == 1 and pending[0].role == "user":
                    # Reply to the open user turn: complete the pair
                    units.append(pending + [entry])
                    pending = []
                else:
                    # No open user turn: the turn is its own unit
                    units.append([entry])

            if pending:
                # Trailing user turn without a reply
                units.append(pending)

            units_by_agent[agent_id] = units

        # Round-robin over agents, one unit per agent per round
        rounds = max((len(units) for units in units_by_agent.values()), default=0)
        for round_index in range(rounds):
            for agent_id in agent_ids:
                agent_units = units_by_agent.get(agent_id, [])
                if round_index < len(agent_units):
                    ordered.extend(agent_units[round_index])

    return ordered
283
+
284
+
285
def chat_turns_to_dict(chat_turns: Iterator[ChatTurnLog]) -> Iterator[Dict[str, Any]]:
    """Lazily yield the `model_dump()` dict of every chat turn, in input order."""
    yield from (turn.model_dump() for turn in chat_turns)
289
+
290
+
291
def get_all_agents(root: RolloutTreeRootNode) -> List[str]:
    """Return the sorted agent IDs found in the first node of the tree.

    The IDs are the union of the keys of the first node's action logs and of
    its rewards; later nodes are not inspected. An empty tree yields [].
    """
    head = root.child
    if head is None:
        return []

    # A branch node is represented by its main child
    if isinstance(head, RolloutTreeBranchNode):
        head = head.main_child
        if head is None:
            return []

    step_log = head.step_log
    agent_ids = set(step_log.action_logs.keys()) | set(
        step_log.simulation_step_log.rewards.keys()
    )
    return sorted(agent_ids)
309
+
310
+
311
def gather_agent_main_rewards(agent_id: str, path: RolloutNodeList) -> List[float]:
    """Return the reward of `agent_id` at every node of `path`, in node order.

    Raises:
        KeyError: If a node has no reward entry for `agent_id`.
    """
    return [
        node.step_log.simulation_step_log.rewards[agent_id] for node in path.nodes
    ]
318
+
319
+
320
def gather_all_rewards(path: RolloutNodeList) -> List[Dict[AgentId, float]]:
    """Return a copy of the per-agent reward mapping of every node of `path`, in node order."""
    return [node.step_log.simulation_step_log.rewards.copy() for node in path.nodes]
326
+
327
+
328
def gather_simulation_stats(
    path: RolloutNodeList,
    filter: Callable[[SimulationStepLog], bool],
    stat_func: Callable[[SimulationStepLog], Any],
) -> List[Any]:
    """Apply `stat_func` to the simulation step log of every node of `path` accepted by `filter`.

    Results are returned in node order.
    """
    step_logs = (node.step_log.simulation_step_log for node in path.nodes)
    return [stat_func(step_log) for step_log in step_logs if filter(step_log)]
340
+
341
+
342
def gather_simulation_infos(path: RolloutNodeList) -> List[Dict[str, Any]]:
    """Return the `info` of the simulation step log of every node of `path`, in node order.

    The info objects are returned as-is, not copied.
    """
    return [node.step_log.simulation_step_log.info for node in path.nodes]
348
+
349
+
350
def export_chat_logs(path: Path, outdir: Path):
    """Export the chat turns of a pickled rollout tree as a JSONL file.

    The tree at `path` is loaded and split into its main path and branch
    paths. One JSON object per path, with keys "path_id" and "chat_turns",
    is written to `outdir / "mgid:<root id>_plucked_chats.render.jsonl"`;
    `outdir` is created if needed.
    """
    import json

    root = load_rollout_tree(path)
    game_id = root.id

    main_path, branch_paths = get_rollout_tree_paths(root)

    outdir.mkdir(parents=True, exist_ok=True)
    target = outdir / f"mgid:{game_id}_plucked_chats.render.jsonl"

    with open(target, "w", encoding="utf-8") as sink:
        # Main path first, then every branch path
        for rollout_path in [main_path] + branch_paths:
            turns = gather_all_chat_turns_for_path(rollout_path)
            record = {
                "path_id": str(rollout_path.id),
                "chat_turns": [turn.model_dump() for turn in turns],
            }
            sink.write(json.dumps(record, ensure_ascii=False) + "\n")
373
+
374
+
375
def _reward_row_sort_key(row: List[str]):
    """Return a sort key ordering reward rows by (group_seed, mgid).

    Cells that parse as integers sort numerically and ahead of non-numeric
    cells, which fall back to string order. The leading 0/1 flags guarantee
    that an int is never compared with a str.
    """

    def as_int(cell: str):
        try:
            return int(cell)
        except (TypeError, ValueError):
            return None

    mgid_raw = row[0] if len(row) > 0 else ""
    seed_raw = row[1] if len(row) > 1 else ""
    seed_num = as_int(seed_raw)
    mgid_num = as_int(mgid_raw)
    return (
        0 if seed_num is not None else 1,
        seed_num if seed_num is not None else seed_raw,
        0 if mgid_num is not None else 1,
        mgid_num if mgid_num is not None else mgid_raw,
    )


def _read_reward_rows(csv_path: Path) -> List[List[str]]:
    """Read the data rows of an existing rewards CSV, skipping blanks and the header."""
    if not csv_path.exists():
        return []
    rows: List[List[str]] = []
    with open(csv_path, "r", newline="") as rf:
        for row in csv.reader(rf):
            # Empty rows and rows made only of whitespace cells carry no data.
            if not any(cell.strip() for cell in row):
                continue
            is_header = (
                len(row) >= 2
                and row[0].strip().lower() == "mgid"
                and row[1].strip().lower() == "group_seed"
            )
            if is_header:
                continue
            rows.append(row)
    return rows


def export_rewards_to_csv(path: Path, outdir: Path, first_file: bool):
    """Merge the main-path rewards of one rollout tree into per-agent CSV files.

    For every agent found in the first step's reward mapping, the file
    ``outdir / "agent:<agent_id>_rewards.render.csv"`` is rewritten with a
    header (``mgid, group_seed, r_t0, r_t1, ...``), the rows it already held,
    and one new row for this rollout. Rows are sorted by (group_seed, mgid)
    and right-padded with empty cells to a common width.

    Args:
        path: Rollout tree file, read with ``load_rollout_tree``.
        outdir: Directory receiving the CSV files (created if missing).
        first_file: Not used by this function; kept so existing callers
            keep working.

    If the main path has no steps there is nothing to export and no file is
    written (previously this raised ``IndexError``).
    """
    root = load_rollout_tree(path)
    mgid = root.id
    group_seed = getattr(root, "crn_id", None)

    main_path, _branch_paths = get_rollout_tree_paths(root)
    outdir.mkdir(parents=True, exist_ok=True)

    rewards_dict_list = gather_all_rewards(main_path)
    if not rewards_dict_list:
        return

    # The first step defines which agents are exported.
    agent_ids = list(rewards_dict_list[0])

    for agent_id in agent_ids:
        output_file = outdir / f"agent:{agent_id}_rewards.render.csv"

        # Current row: [mgid, group_seed] followed by one cell per time step.
        formatted_rewards = [
            f"{round(step_rewards[agent_id], 1):>5}"
            for step_rewards in rewards_dict_list
        ]
        current_row = [str(mgid), str(group_seed)] + formatted_rewards

        rows = _read_reward_rows(output_file)
        rows.append(current_row)
        rows.sort(key=_reward_row_sort_key)

        # Widest reward sequence decides the header and the padding width.
        max_reward_len = max(0, max(len(r) - 2 for r in rows))
        row_width = 2 + max_reward_len
        padded_rows = [r + [""] * (row_width - len(r)) for r in rows]

        header = ["mgid", "group_seed"] + [f"r_t{t}" for t in range(max_reward_len)]

        # Rewrite the whole file so it always has one header and no blank rows.
        with open(output_file, "w", newline="") as wf:
            writer = csv.writer(wf)
            writer.writerow(header)
            writer.writerows(padded_rows)
462
+
463
+
464
+ # --------------------------------------------------------------------------------------
465
+ # HTML exports
466
+ # --------------------------------------------------------------------------------------
467
+
468
+
469
def html_from_chat_turns(chat_turns: List[ChatTurnLog]) -> str:
    """
    Render chat turns as a single, wrapping sequence of messages in time order.

    The result is one self-contained HTML page (inline CSS and JavaScript).
    Every turn shows a badge with the agent name (plus the reward for
    assistant turns), an optional reasoning segment, and the message text.
    Each segment can be hidden/shown by click; when hidden, only the badge
    remains and "(...)" is shown inline. User-role segments start collapsed,
    and reasoning segments always start collapsed.

    Args:
        chat_turns: Turns to render. Each turn must expose ``time_step``,
            ``agent_id``, ``role``, ``content`` and ``reward``; ``reasoning``
            is optional.

    Returns:
        The full HTML document as a string.

    All turn-derived values (content, reasoning, agent id, role, reward and
    time step) are HTML-escaped or sanitized before being interpolated.
    """
    import html

    # sorted() is stable, so turns sharing a time step keep their input order.
    ordered_turns = sorted(chat_turns, key=lambda turn: turn.time_step)

    # CSS styles (simplified layout; no time-step or agent-column backgrounds)
    css = """
    <style>
        :root {
            --font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            --bg: #ffffff;
            --text: #1c0b00;
            --muted-text: #2C3E50;
            --accent-muted: #BDC3C7;
            --accent-muted-2: #D0D7DE;
            --panel-bg: #F8FAFC;
            --reward-color: #a89206;
            --font-size: 15px;
            --small-font-size: 13px;
            --group-label-font-size: 12px;
            --border-width: 2px;
            --corner-radius: 6px;
            --pill-radius-left: 999px 0 0 999px;
            --pill-radius-right: 0 999px 999px 0;
            --inset-shadow: 0 1px 0 rgba(0,0,0,0.03) inset;
        }
        body {
            font-family: var(--font-family);
            margin: 16px;
            background-color: var(--bg);
            color: var(--text);
            font-size: var(--font-size);
            line-height: 1.6;
        }
        .messages-flow {
            display: block; /* behave like a text container */
        }
        .toolbar {
            display: flex;
            align-items: center;
            gap: 8px;
            margin-bottom: 0;
            font-size: var(--small-font-size);
            max-height: 0;
            overflow: hidden;
            opacity: 0;
            pointer-events: none;
            transition: max-height 0.2s ease, opacity 0.2s ease;
        }
        .toolbar-wrap { position: sticky; top: 0; z-index: 10; background: var(--bg); }
        .toolbar-hotzone { height: 6px; }
        .toolbar-wrap:hover .toolbar { max-height: 200px; opacity: 1; pointer-events: auto; margin-bottom: 12px; }
        .toolbar input[type="number"] {
            width: 72px;
            padding: 2px 6px;
            border: 1px solid var(--accent-muted);
            border-radius: var(--corner-radius);
            background: var(--bg);
        }
        .toolbar button {
            padding: 4px 8px;
            border: 1px solid var(--accent-muted);
            background: var(--panel-bg);
            border-radius: var(--corner-radius);
            cursor: pointer;
        }
        .chat-turn {
            display: inline; /* inline like text */
            background: transparent;
            position: relative;
            cursor: default;
        }
        /* No agent-specific background distinctions */
        .turn-content {
            white-space: normal;
            color: var(--text);
            font-size: var(--font-size);
            display: inline; /* inline flow */
        }
        .chat-turn .agent-badge { margin-right: 0; vertical-align: baseline; }
        .agent-badge {
            display: inline;
            position: relative;
            border: var(--border-width) solid var(--accent-muted); /* slightly thicker */
            border-radius: var(--pill-radius-left); /* round left and bottom-right */
            font-size: var(--font-size);
            color: var(--muted-text);
            background: transparent;
            box-shadow: var(--inset-shadow);
            line-height: 1.35;
            padding: 3px 10px;
            border-right: 0;
            cursor: default;
        }
        .agent-badge::after {
            content: none;
        }
        /* removed external separator; emoji is rendered inside message bubble */
        .agent-name { font-weight: 700; }
        .emoji-bw { filter: grayscale(100%); opacity: 0.95; font-size: var(--font-size); vertical-align: baseline; margin: 0; position: relative; top: -1px; line-height: 1; display: inline-block; }
        .ts-badge {
            position: relative;
            display: inline;
            border: var(--border-width) solid var(--accent-muted-2); /* slightly thicker */
            border-radius: var(--corner-radius); /* not a pill */
            font-size: var(--font-size);
            font-weight: 700;
            color: var(--muted-text);
            background: #F4F8FB; /* subtle tint */
            padding: 1px 6px; /* slight padding for visibility */
            margin-right: 8px; /* small gap from following content */
            pointer-events: auto; /* allow events so we can ignore them in JS */
        }
        /* Hide timestep badges when grouping by 1 */
        .hide-ts-badges .ts-badge { display: none; }
        /* Strong hide: completely hide collapsed segments */
        .strong-hide .segment.collapsed { display: none; }
        .ts-badge::before {
            content: "";
            position: relative;
            background: var(--accent-muted-2);
            border-radius: 2px;
        }
        .agent-badge { margin-left: 6px; }
        /* Segments (reasoning and message) */
        .segment {
            display: inline; /* inline bubble behaving like text */
            font-size: var(--font-size);
            position: relative;
            background: var(--bg);
            vertical-align: baseline;
            line-height: 1.35;
            cursor: pointer;
        }
        .message-box, .reasoning-box {
            display: inline; /* inline bubble behaving like text */
            font-size: var(--font-size);
            border: var(--border-width) solid var(--accent-muted);
            border-radius: var(--pill-radius-right); /* message defaults to pill-right */
            position: relative;
            background: var(--bg);
            vertical-align: baseline;
            line-height: 1.35;
            padding: 3px 10px;
            border-left: 0;
        }
        /* Reasoning between badge and message: no left or right rounding, seam on both sides */
        .reasoning-box {
            border-radius: 0;
            border-left: 0;
            border-right: 0;
        }
        /* Reasoning text style: slightly smaller and slightly gray */
        .reasoning-box .seg-text {
            font-size: var(--small-font-size);
            color: #6b7280;
        }
        .message-box::before { content: none; display: none; margin-right: 0; line-height: 1; }
        .reasoning-box::before { content: none; display: none; margin-right: 0; line-height: 1; }
        /* Segment collapsed behavior */
        .segment .seg-text { display: inline; }
        .segment.collapsed .seg-text { color: transparent; font-size: 0; display: inline-block; }
        .segment.collapsed::after { content: "(...)"; color: #7f8c8d; font-style: italic; font-size: var(--font-size); line-height: 1.2; }
        .segment.collapsed .emoji-bw { opacity: 0.3; }
        .chat-turn.agent-alice.role-assistant .message-box::before { color: #0eb224; }
        .chat-turn.agent-bob.role-assistant .message-box::before { color: #ef8323; }
        .chat-turn.collapsed .message-box::before { display: none; }
        /* Assistant bubble border colors by common agent names */
        .chat-turn.agent-alice.role-assistant .message-box { border-color: #0eb224; }
        .chat-turn.agent-bob.role-assistant .message-box { border-color: #ef8323; }
        .chat-turn.agent-alice.role-assistant .reasoning-box { border-color: #0eb224; }
        .chat-turn.agent-bob.role-assistant .reasoning-box { border-color: #ef8323; }
        /* Tie badge and seam to agent color for a cohesive capsule, assistants only */
        .chat-turn.agent-alice.role-assistant .agent-badge { border-color: #0eb224; }
        .chat-turn.agent-alice.role-assistant .agent-badge::after { border-right-color: #0eb224; }
        .chat-turn.agent-alice.role-assistant .turn-content::before { border-left-color: #0eb224; border-top-color: #0eb224; }
        .chat-turn.agent-alice.role-assistant .message-box { border-color: #0eb224; }

        .chat-turn.agent-bob.role-assistant .agent-badge { border-color: #ef8323; }
        .chat-turn.agent-bob.role-assistant .agent-badge::after { border-right-color: #ef8323; }
        .chat-turn.agent-bob.role-assistant .turn-content::before { border-left-color: #ef8323; border-top-color: #ef8323; }
        .chat-turn.agent-bob.role-assistant .message-box { border-color: #ef8323; }
        /* No colored agent-name; keep neutral */
        .reward { color: var(--reward-color); font-weight: 600; } /* dark gold */
        .message-placeholder { display: none; color: #7f8c8d; font-style: italic; }
        /* Group divider - clearer and pretty */
        .group-divider {
            display: flex;
            align-items: center;
            gap: 8px;
            width: 100%;
            margin: 8px 0 2px 0;
        }
        .group-divider::before,
        .group-divider::after {
            content: "";
            flex: 1 1 auto;
            height: 2px;
            background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted-2) 30%, var(--accent-muted-2) 70%, rgba(224,230,235,0));
        }
        .group-divider .group-label {
            display: inline-block;
            border: 1px solid var(--accent-muted);
            border-radius: 999px;
            padding: 2px 10px;
            font-size: var(--group-label-font-size);
            font-weight: 700;
            color: var(--muted-text);
            background: var(--bg);
            box-shadow: var(--inset-shadow);
        }
        .chat-turn .turn-content { position: relative; }
        .chat-turn .turn-content::before {
            content: none;
        }
        .chat-turn .agent-badge {
            position: relative;
        }
        /* removed absolute-positioned emoji to prevent overlap */
    </style>
    """

    # HTML structure
    html_parts = [
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        "<meta charset='UTF-8'>",
        "<title>Chat Turns</title>",
        css,
        "<script>\n"
        "document.addEventListener('DOMContentLoaded', function() {\n"
        " const flow = document.querySelector('.messages-flow');\n"
        " // State for range filtering and strong hide\n"
        " let currentRangeStart = null;\n"
        " let currentRangeEnd = null;\n"
        " let strongHideOn = false;\n"
        " // Toggle collapse per message\n"
        " document.body.addEventListener('click', function(e){\n"
        " if (e.target.closest('.ts-badge')) { return; }\n"
        " if (e.target.closest('.agent-badge')) { return; }\n"
        " const seg = e.target.closest('.segment');\n"
        " if (seg) { e.stopPropagation(); seg.classList.toggle('collapsed'); }\n"
        " });\n"
        " // Grouping logic\n"
        " function applyRangeFilter() {\n"
        " const turns = Array.from(flow.querySelectorAll('.chat-turn'));\n"
        " for (const el of turns) {\n"
        " const t = parseInt(el.getAttribute('data-time-step') || '0', 10);\n"
        " const afterStart = (currentRangeStart === null) || (t >= currentRangeStart);\n"
        " const beforeEnd = (currentRangeEnd === null) || (t <= currentRangeEnd);\n"
        " el.style.display = (afterStart && beforeEnd) ? '' : 'none';\n"
        " }\n"
        " // Hide group headers that have no visible turns in their section\n"
        " const dividers = Array.from(flow.querySelectorAll('.group-divider'));\n"
        " for (const d of dividers) {\n"
        " let anyVisible = false;\n"
        " let el = d.nextElementSibling;\n"
        " while (el && !el.classList.contains('group-divider')) {\n"
        " if (el.classList.contains('chat-turn')) {\n"
        " const disp = getComputedStyle(el).display;\n"
        " if (disp !== 'none') { anyVisible = true; break; }\n"
        " }\n"
        " el = el.nextElementSibling;\n"
        " }\n"
        " d.style.display = anyVisible ? '' : 'none';\n"
        " }\n"
        " }\n"
        " function applyGrouping(n) {\n"
        " // Remove existing group dividers\n"
        " Array.from(flow.querySelectorAll('.group-divider')).forEach(el => el.remove());\n"
        " if (!n || n <= 0) { return; }\n"
        " const turns = Array.from(flow.querySelectorAll('.chat-turn'));\n"
        " if (turns.length === 0) return;\n"
        " // Re-append in order with dividers\n"
        " const items = Array.from(flow.children).filter(el => !el.classList.contains('group-divider'));\n"
        " const frag = document.createDocumentFragment();\n"
        " let lastGroup = -1;\n"
        " for (const el of items) {\n"
        " if (!el.classList.contains('chat-turn')) { frag.appendChild(el); continue; }\n"
        " const t = parseInt(el.getAttribute('data-time-step') || '0', 10);\n"
        " const g = Math.floor(t / n);\n"
        " if (g !== lastGroup) {\n"
        " const div = document.createElement('div');\n"
        " div.className = 'group-divider';\n"
        " const label = document.createElement('span');\n"
        " label.className = 'group-label';\n"
        " const start = g * n;\n"
        " const end = start + n - 1;\n"
        " const roundIndex = g + 1;\n"
        " label.textContent = `Round ${roundIndex}`;\n"
        " div.appendChild(label);\n"
        " frag.appendChild(div);\n"
        " lastGroup = g;\n"
        " }\n"
        " frag.appendChild(el);\n"
        " }\n"
        " flow.innerHTML = '';\n"
        " flow.appendChild(frag);\n"
        " // Hide timestep badges when grouping is 1\n"
        " flow.classList.toggle('hide-ts-badges', n === 1);\n"
        " // Keep strong hide state\n"
        " flow.classList.toggle('strong-hide', strongHideOn);\n"
        " // Re-apply range filter after regrouping\n"
        " applyRangeFilter();\n"
        " }\n"
        " const input = document.getElementById('group-size');\n"
        " const btn = document.getElementById('apply-grouping');\n"
        " if (btn && input) {\n"
        " btn.addEventListener('click', () => { const n = parseInt(input.value || '0', 10); applyGrouping(n); });\n"
        " input.addEventListener('keydown', (e) => { if (e.key === 'Enter') { const n = parseInt(input.value || '0', 10); applyGrouping(n); } });\n"
        " }\n"
        " // Default grouping to 1 timestep on load\n"
        " if (input) { input.value = '1'; applyGrouping(1); }\n"
        " // Range filter controls\n"
        " const rangeStart = document.getElementById('range-start');\n"
        " const rangeEnd = document.getElementById('range-end');\n"
        " const rangeBtn = document.getElementById('apply-range');\n"
        " if (rangeBtn && rangeStart && rangeEnd) {\n"
        " const applyRange = () => {\n"
        " const sv = parseInt(rangeStart.value || '', 10);\n"
        " const ev = parseInt(rangeEnd.value || '', 10);\n"
        " currentRangeStart = Number.isFinite(sv) ? sv : null;\n"
        " currentRangeEnd = Number.isFinite(ev) ? ev : null;\n"
        " applyRangeFilter();\n"
        " };\n"
        " rangeBtn.addEventListener('click', applyRange);\n"
        " rangeStart.addEventListener('keydown', (e) => { if (e.key === 'Enter') applyRange(); });\n"
        " rangeEnd.addEventListener('keydown', (e) => { if (e.key === 'Enter') applyRange(); });\n"
        " }\n"
        " // Strong hide toggle (on by default)\n"
        " const strongHideBtn = document.getElementById('toggle-strong-hide');\n"
        " const strongHideStateEl = document.getElementById('strong-hide-state');\n"
        " if (strongHideBtn) {\n"
        " const setLabel = () => { if (strongHideStateEl) { strongHideStateEl.textContent = strongHideOn ? 'On' : 'Off'; } };\n"
        " strongHideBtn.addEventListener('click', () => { strongHideOn = !strongHideOn; flow.classList.toggle('strong-hide', strongHideOn); setLabel(); });\n"
        " flow.classList.toggle('strong-hide', strongHideOn);\n"
        " setLabel();\n"
        " }\n"
        "});\n"
        "</script>",
        "</head>",
        "<body>",
        '<div class="toolbar-wrap">',
        '<div class="toolbar-hotzone"></div>',
        '<div class="toolbar">',
        '<label for="group-size">Group every</label>',
        '<input id="group-size" type="number" min="0" step="1" value="1" />',
        "<span>timesteps</span>",
        '<button id="apply-grouping">Apply</button>',
        '<span style="margin-left:8px"></span>',
        '<label for="range-start"><span class="emoji-bw">🔎</span> Range</label>',
        '<input id="range-start" type="number" step="1" />',
        "<span>to</span>",
        '<input id="range-end" type="number" step="1" />',
        '<button id="apply-range"><span class="emoji-bw">▶︎</span> Apply</button>',
        '<button id="toggle-strong-hide"><span class="emoji-bw">🗜️</span> Strong Hide: <span id="strong-hide-state">On</span></button>',
        "</div>",
        "</div>",
        '<div class="messages-flow">',
    ]

    last_time_step = None
    for turn in ordered_turns:
        # CSS classes. The agent id is reduced to a safe class token; the role is
        # escaped because it is interpolated inside a class="..." attribute.
        agent_class = f"agent-{re.sub('[^a-z0-9_-]', '-', turn.agent_id.lower())}"
        role_text = html.escape(f"{turn.role}")
        role_class = f"role-{role_text}"
        time_step_text = html.escape(f"{turn.time_step}")
        # Segments default collapsed for user role
        segment_collapsed_class = " collapsed" if turn.role == "user" else ""

        # Badge content
        if turn.role == "assistant":
            name = html.escape(turn.agent_id)
            emoji = '<span class="emoji-bw">🤖</span>'
            raw_val = turn.reward
            if isinstance(raw_val, (int, float)):
                # Up to 4 decimals, trailing zeros removed, capped at 8 characters.
                reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".")
                if len(reward_val) > 8:
                    reward_val = reward_val[:8] + "…"
            else:
                # Non-numeric rewards (e.g. None) are shown as text, escaped for HTML.
                reward_val = html.escape(str(raw_val))
            # Format: "🤖 Alice • 5.5556 r • "
            badge_inner = (
                f'{emoji} <span class="agent-name">{name}</span>'
                f' <span class="sep"> • </span><span class="reward">{reward_val} r</span>'
                f' <span class="sep"> • </span>'
            )
        else:
            # For user messages, show "User of {Agent ID}" in the badge
            name = "User of " + html.escape(turn.agent_id)
            emoji = '<span class="emoji-bw">⚙️</span>'
            # Format (no reward): "⚙️ User of Alice • "
            badge_inner = f'{emoji} <span class="agent-name">{name}</span><span class="sep"> • </span>'

        badge = f'<span class="agent-badge">{badge_inner}</span>'

        # Inline timestep distinction badge at step boundaries (render before first message)
        ts_badge_html = ""
        if last_time_step is None or turn.time_step != last_time_step:
            ts_badge_html = f'<span class="ts-badge">⏱ {time_step_text}</span>'
            last_time_step = turn.time_step

        # Escape first, then squeeze whitespace so each message flows inline.
        escaped_content = html.escape(turn.content)
        collapsed_text = re.sub(r"\s+", " ", escaped_content).strip()

        # Optional reasoning; always rendered collapsed.
        reasoning_val = getattr(turn, "reasoning", None)
        reasoning_html = ""
        if reasoning_val:
            escaped_reasoning = html.escape(reasoning_val)
            reasoning_text = re.sub(r"\s+", " ", escaped_reasoning).strip()
            reasoning_html = (
                f'<span class="segment reasoning-box collapsed">'
                f'<span class="emoji-bw">💭</span> '
                f'<span class="seg-text"><i>{reasoning_text} </i></span>'
                f"</span>"
            )

        html_parts.append(
            f'<div class="chat-turn {agent_class} {role_class}" data-time-step="{time_step_text}">'
            f'<div class="turn-content {agent_class} {role_class}">{ts_badge_html}{badge}'
            f"{reasoning_html}"
            f'<span class="segment message-box{segment_collapsed_class}"><span class="emoji-bw">💬</span> <span class="seg-text">{collapsed_text}</span></span>'
            f"</div>"
            f"</div>"
        )

    html_parts.extend(["</div>", "</body>", "</html>"])

    return "\n".join(html_parts)
907
+
908
+
909
def export_html_from_rollout_tree(path: Path, outdir: Path, main_only: bool = False):
    """Process a rollout tree file and generate HTML files for its paths.

    The main path is written to
    ``outdir / "mgid:<root id>_main_html_render.render.html"``. Unless
    ``main_only`` is set, every branch path is written to
    ``<path id>_html_render.render.html`` inside the
    ``mgid:<root id>_branches_html_renders`` subdirectory of ``outdir``.

    Args:
        path: Path to the serialized rollout tree, read with ``load_rollout_tree``.
        outdir: Output directory for HTML files (created if missing).
        main_only: If True, only export the main trajectory (default: False).
    """
    root = load_rollout_tree(path)
    mgid = root.id

    main_path, branch_paths = get_rollout_tree_paths(root)

    outdir.mkdir(parents=True, exist_ok=True)

    # Branches are exported only when requested and when there are any.
    # Previously the branch loop ran even with main_only=True and then failed
    # on the undefined branches directory.
    export_branches = bool(branch_paths) and not main_only
    if export_branches:
        branches_dir = outdir / f"mgid:{mgid}_branches_html_renders"
        branches_dir.mkdir(parents=True, exist_ok=True)

    # Generate HTML for the main path
    chat_turns = gather_all_chat_turns_for_path(main_path)
    html_content = html_from_chat_turns(chat_turns)
    output_file = outdir / f"mgid:{mgid}_main_html_render.render.html"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(html_content)

    if not export_branches:
        return

    # Generate HTML for each branch path
    for path_obj in branch_paths:
        chat_turns = gather_all_chat_turns_for_path(path_obj)
        html_content = html_from_chat_turns(chat_turns)

        path_id: str = path_obj.id
        output_file = branches_dir / f"{path_id}_html_render.render.html"

        with open(output_file, "w", encoding="utf-8") as f:
            f.write(html_content)
src_code_for_reproducibility/markov_games/mg_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass
3
+ import copy
4
+ import asyncio
5
+
6
+ from mllm.markov_games.ipd.ipd_agent import IPDAgent
7
+ from mllm.markov_games.ipd.ipd_simulation import IPD
8
+ from mllm.markov_games.markov_game import MarkovGame
9
+ from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
10
+ from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation
11
+ from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
12
+ from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
13
+ from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
14
+ from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
15
+ from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
16
+ from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
17
+
18
+ from mllm.markov_games.markov_game import MarkovGame
19
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode, StepLog, RolloutTreeBranchNode
20
+ from mllm.markov_games.rollout_tree import AgentActLog
21
+ from mllm.markov_games.simulation import SimulationStepLog
22
+ from mllm.markov_games.rollout_tree import RolloutTreeNode
23
+
24
+ AgentId = str
25
+
26
+
27
+
28
+
29
@dataclass
class AgentConfig:
    """Settings used to build one agent in ``init_markov_game_components``."""

    # Passed to the agent constructor and used as the agent's key in the game's
    # agent mapping.
    # NOTE(review): annotated as int although this module defines AgentId = str — confirm.
    agent_id: int
    # Name of the agent class; resolved with eval() in init_markov_game_components.
    agent_class_name: str
    # Key into the `policies` mapping given to init_markov_game_components.
    policy_id: str
    # Extra keyword arguments forwarded to the agent constructor.
    init_kwargs: dict
35
+
36
+
37
@dataclass
class MarkovGameConfig:
    """Settings used to build a ``MarkovGame`` in ``init_markov_game_components``."""

    # Passed to MarkovGame as `id`.
    id: int
    # Given to the simulation and to every agent as `seed`, and to MarkovGame as `crn_id`.
    seed: int
    # Name of the simulation class; resolved with eval() in init_markov_game_components.
    simulation_class_name: str
    # Extra keyword arguments forwarded to the simulation constructor.
    simulation_init_args: dict
    # One entry per agent to create.
    agent_configs: list[AgentConfig]
44
+
45
+
46
def init_markov_game_components(
    config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
):
    """Build a ``MarkovGame`` (simulation plus agents) from a configuration.

    Args:
        config: Game description: id, seed, simulation class name and init
            arguments, and one ``AgentConfig`` per agent.
        policies: Mapping from policy id to policy callable; each agent gets
            the entry named by its ``policy_id``.

    Returns:
        A ``MarkovGame`` whose ``crn_id`` is ``config.seed`` and whose agents
        are keyed by ``agent_id``.
    """
    # NOTE(security): class names are resolved with eval(), which executes
    # arbitrary expressions. Only use configs from trusted sources.
    simulation_class = eval(config.simulation_class_name)
    simulation = simulation_class(seed=config.seed, **config.simulation_init_args)

    agents = {}
    for agent_cfg in config.agent_configs:
        agent_class = eval(agent_cfg.agent_class_name)
        # The simulation and every agent share the same seed.
        agents[agent_cfg.agent_id] = agent_class(
            seed=config.seed,
            agent_id=agent_cfg.agent_id,
            policy=policies[agent_cfg.policy_id],
            **agent_cfg.init_kwargs,
        )

    return MarkovGame(
        id=config.id,
        crn_id=config.seed,
        simulation=simulation,
        agents=agents,
    )
74
+
75
+
76
+
77
+
src_code_for_reproducibility/markov_games/simulation.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A Simulation is the environment of a Markov Game.
3
+ The Simulation is not responsible for properly checking / formatting the responses of LLM's.
4
+ This is the job of the `Agent` class.
5
+ Simulations expect clean actions, and are defined similarly to `gymnasium` environments, except that they are adapted for the Multi-agent setting.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Tuple
10
+
11
+ from numpy.random import default_rng
12
+
13
+ from mllm.markov_games.rollout_tree import SimulationStepLog
14
+
15
+
16
class Simulation(ABC):
    """Abstract interface of a Markov game environment.

    Only ``__init__`` and ``step`` are declared abstract; every other method
    is an optional hook that raises ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def __init__(self, seed: int, *args, **kwargs):
        # Abstract, yet it has a body: it stores the seed and creates a numpy
        # random generator from it.
        # NOTE(review): presumably subclasses call super().__init__(seed) — confirm.
        self.seed = seed
        self.rng = default_rng(self.seed)

    @abstractmethod
    def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
        """
        Advance the simulation with the given actions.

        Returns a ``(terminated, step_log)`` pair.
        """
        raise NotImplementedError

    def get_obs(self):
        """Returns the observations of all agents as a dict.

        Returns:
            observations
        """
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id"""
        raise NotImplementedError

    def get_obs_size(self):
        """Returns the shape of the observation"""
        raise NotImplementedError

    def get_state(self):
        raise NotImplementedError

    def get_state_size(self):
        """Returns the shape of the state"""
        raise NotImplementedError

    def get_avail_actions(self):
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id"""
        raise NotImplementedError

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take"""
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        raise NotImplementedError

    def get_safe_copy(self):
        """
        Return a copy of the simulation object that is decorrelated from the original object.
        """
        raise NotImplementedError

    def reset(self):
        """Returns initial observations and states"""
        raise NotImplementedError

    def render(self):
        raise NotImplementedError

    def close(self):
        raise NotImplementedError

    # NOTE(review): kept disabled; a `seed` method would be shadowed on instances
    # by the `self.seed` attribute set in __init__ — confirm before re-enabling.
    # def seed(self):
    #     raise NotImplementedError

    def save_replay(self):
        raise NotImplementedError

    def get_simulation_info(self):
        raise NotImplementedError
src_code_for_reproducibility/markov_games/statistics_runner.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import json
5
+ import pickle
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional
9
+
10
+ from basic_render import find_iteration_folders
11
+
12
+ from mllm.markov_games.rollout_tree import (
13
+ RolloutTreeBranchNode,
14
+ RolloutTreeNode,
15
+ RolloutTreeRootNode,
16
+ SimulationStepLog,
17
+ )
18
+
19
+
20
def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]:
    """
    Lazily walk the main trajectory below ``root``, yielding its step nodes.

    Branch nodes are not yielded; the walk continues through their
    ``main_child``. It stops at ``None`` or at any other kind of object.
    """
    cursor = root.child
    while isinstance(cursor, (RolloutTreeNode, RolloutTreeBranchNode)):
        if isinstance(cursor, RolloutTreeNode):
            yield cursor
            cursor = cursor.child
        else:
            # Branch node: stay on the main trajectory.
            cursor = cursor.main_child
34
+
35
+
36
def iterate_main_simulation_logs(
    root: RolloutTreeRootNode,
) -> Iterator[SimulationStepLog]:
    """Lazily yield the simulation step log of every main-path node under ``root``."""
    yield from (
        main_node.step_log.simulation_step_log
        for main_node in _iterate_main_nodes(root)
    )
41
+
42
+
43
def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
    """Lazily yield every regular file matching ``*.rt.pkl`` under ``iteration_folder``.

    The search is recursive; directories whose name matches are skipped.
    """
    yield from (
        candidate
        for candidate in iteration_folder.rglob("*.rt.pkl")
        if candidate.is_file()
    )
47
+
48
+
49
def load_root(path: Path) -> RolloutTreeRootNode:
    """Unpickle the file at ``path`` and validate it as a ``RolloutTreeRootNode``."""
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # rollout files from trusted sources.
    with open(path, "rb") as stream:
        payload = pickle.load(stream)
    return RolloutTreeRootNode.model_validate(payload)
53
+
54
+
55
@dataclass
class StatRecord:
    """One row of computed statistics for a single rollout."""

    # Identifier of the rollout the values belong to.
    mgid: int
    # Optional secondary identifier; None when the rollout has none.
    crn_id: Optional[int]
    # Name of the iteration the rollout was found in.
    iteration: str
    # Mapping of stat name to its computed value.
    values: Dict[str, Any]
61
+
62
+
63
class StatComputer:
    """
    Interface for a stateful statistic accumulator.

    An instance is fed the SimulationStepLog objects of one rollout (mgid)
    through ``update`` and then asked once, through ``finalize``, for the
    aggregated values. Both methods must be overridden by subclasses.
    """

    def update(self, sl: SimulationStepLog) -> None:  # pragma: no cover - interface
        """Consume one step log; subclasses accumulate state here."""
        raise NotImplementedError

    def finalize(self) -> Dict[str, Any]:  # pragma: no cover - interface
        """Return the aggregated stats as a mapping of name to value."""
        raise NotImplementedError
74
+
75
+
76
+ def run_stats(
77
+ data_root: Path,
78
+ game_name: str,
79
+ make_computers: Callable[[], List[StatComputer]],
80
+ output_filename: Optional[str] = None,
81
+ output_format: str = "json", # "json" (dict of lists) or "jsonl"
82
+ ) -> Path:
83
+ """
84
+ Compute stats across all iteration_* folders under data_root.
85
+ Writes JSONL to data_root/statistics/<output_filename or f"{game_name}.stats.jsonl">.
86
+ """
87
+ data_root = Path(data_root)
88
+ outdir = data_root / "statistics"
89
+ outdir.mkdir(parents=True, exist_ok=True)
90
+ # Choose extension by format
91
+ default_name = (
92
+ f"{game_name}.stats.json"
93
+ if output_format == "json"
94
+ else f"{game_name}.stats.jsonl"
95
+ )
96
+ outfile = outdir / (
97
+ output_filename if output_filename is not None else default_name
98
+ )
99
+
100
+ # Rewrite file each run to keep it clean and small
101
+ if outfile.exists():
102
+ outfile.unlink()
103
+
104
+ iteration_folders = find_iteration_folders(str(data_root))
105
+
106
+ # If writing JSONL, stream directly; otherwise accumulate minimal records
107
+ if output_format == "jsonl":
108
+ with open(outfile, "w", encoding="utf-8") as w:
109
+ for iteration_folder in iteration_folders:
110
+ iteration_name = Path(iteration_folder).name
111
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
112
+ root = load_root(pkl_path)
113
+
114
+ computers = make_computers()
115
+ for sl in iterate_main_simulation_logs(root):
116
+ for comp in computers:
117
+ try:
118
+ comp.update(sl)
119
+ except Exception:
120
+ continue
121
+
122
+ values: Dict[str, Any] = {}
123
+ for comp in computers:
124
+ try:
125
+ values.update(comp.finalize())
126
+ except Exception:
127
+ continue
128
+
129
+ rec = {
130
+ "mgid": getattr(root, "id", None),
131
+ "crn_id": getattr(root, "crn_id", None),
132
+ "iteration": iteration_name,
133
+ "stats": values,
134
+ }
135
+ w.write(json.dumps(rec, ensure_ascii=False) + "\n")
136
+
137
+ del root
138
+ del computers
139
+ gc.collect()
140
+ else:
141
+ # Aggregate to dict-of-lists for easier plotting
142
+ records: List[Dict[str, Any]] = []
143
+ # Process in deterministic order
144
+ for iteration_folder in iteration_folders:
145
+ iteration_name = Path(iteration_folder).name
146
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
147
+ root = load_root(pkl_path)
148
+
149
+ computers = make_computers()
150
+ for sl in iterate_main_simulation_logs(root):
151
+ for comp in computers:
152
+ try:
153
+ comp.update(sl)
154
+ except Exception:
155
+ continue
156
+
157
+ values: Dict[str, Any] = {}
158
+ for comp in computers:
159
+ try:
160
+ values.update(comp.finalize())
161
+ except Exception:
162
+ continue
163
+
164
+ records.append(
165
+ {
166
+ "mgid": getattr(root, "id", None),
167
+ "crn_id": getattr(root, "crn_id", None),
168
+ "iteration": iteration_name,
169
+ "stats": values,
170
+ }
171
+ )
172
+
173
+ del root
174
+ del computers
175
+ gc.collect()
176
+
177
+ # Build dict-of-lists with nested stats preserved
178
+ # Collect all stat keys and nested agent keys where needed
179
+ mgids: List[Any] = []
180
+ crn_ids: List[Any] = []
181
+ iterations_out: List[str] = []
182
+ # stats_out is a nested structure mirroring keys but with lists
183
+ stats_out: Dict[str, Any] = {}
184
+
185
+ # First pass to collect union of keys
186
+ stat_keys: set[str] = set()
187
+ nested_agent_keys: Dict[str, set[str]] = {}
188
+ for r in records:
189
+ stats = r.get("stats", {}) or {}
190
+ for k, v in stats.items():
191
+ stat_keys.add(k)
192
+ if isinstance(v, dict):
193
+ nested = nested_agent_keys.setdefault(k, set())
194
+ for ak in v.keys():
195
+ nested.add(str(ak))
196
+
197
+ # Initialize structure
198
+ for k in stat_keys:
199
+ if k in nested_agent_keys:
200
+ stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
201
+ else:
202
+ stats_out[k] = []
203
+
204
+ # Fill lists
205
+ for r in records:
206
+ mgids.append(r.get("mgid"))
207
+ crn_ids.append(r.get("crn_id"))
208
+ iterations_out.append(r.get("iteration"))
209
+ stats = r.get("stats", {}) or {}
210
+ for k in stat_keys:
211
+ val = stats.get(k)
212
+ if isinstance(stats_out[k], dict):
213
+ # per-agent dict
214
+ agent_dict = val if isinstance(val, dict) else {}
215
+ for ak in stats_out[k].keys():
216
+ stats_out[k][ak].append(agent_dict.get(ak))
217
+ else:
218
+ stats_out[k].append(val)
219
+
220
+ with open(outfile, "w", encoding="utf-8") as w:
221
+ json.dump(
222
+ {
223
+ "mgid": mgids,
224
+ "crn_id": crn_ids,
225
+ "iteration": iterations_out,
226
+ "stats": stats_out,
227
+ },
228
+ w,
229
+ ensure_ascii=False,
230
+ )
231
+
232
+ return outfile
233
+
234
+
235
+ def run_stats_functional(
236
+ data_root: Path,
237
+ game_name: str,
238
+ metrics: Dict[str, Callable[[SimulationStepLog], Optional[Dict[str, float]]]],
239
+ output_filename: Optional[str] = None,
240
+ output_format: str = "json",
241
+ ) -> Path:
242
+ """
243
+ Functional variant where metrics is a dict of name -> f(SimulationStepLog) -> {agent_id: value}.
244
+ Aggregates per rollout by averaging over steps where a metric produced a value.
245
+ Writes a single consolidated file in data_root/statistics/.
246
+ """
247
+ data_root = Path(data_root)
248
+ outdir = data_root / "statistics"
249
+ outdir.mkdir(parents=True, exist_ok=True)
250
+ default_name = (
251
+ f"{game_name}.stats.json"
252
+ if output_format == "json"
253
+ else f"{game_name}.stats.jsonl"
254
+ )
255
+ outfile = outdir / (
256
+ output_filename if output_filename is not None else default_name
257
+ )
258
+
259
+ if outfile.exists():
260
+ outfile.unlink()
261
+
262
+ iteration_folders = find_iteration_folders(str(data_root))
263
+
264
+ def finalize_rollout(
265
+ agg: Dict[str, Dict[str, List[float]]]
266
+ ) -> Dict[str, Dict[str, float]]:
267
+ # avg per metric per agent
268
+ result: Dict[str, Dict[str, float]] = {}
269
+ for mname, agent_values in agg.items():
270
+ result[mname] = {}
271
+ for aid, vals in agent_values.items():
272
+ if not vals:
273
+ result[mname][aid] = None # keep alignment; could be None
274
+ else:
275
+ result[mname][aid] = sum(vals) / len(vals)
276
+ return result
277
+
278
+ if output_format == "jsonl":
279
+ with open(outfile, "w", encoding="utf-8") as w:
280
+ for iteration_folder in iteration_folders:
281
+ iteration_name = Path(iteration_folder).name
282
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
283
+ root = load_root(pkl_path)
284
+
285
+ # aggregator structure: metric -> agent_id -> list of values
286
+ agg: Dict[str, Dict[str, List[float]]] = {
287
+ m: {} for m in metrics.keys()
288
+ }
289
+
290
+ for sl in iterate_main_simulation_logs(root):
291
+ for mname, fn in metrics.items():
292
+ try:
293
+ vals = fn(sl)
294
+ except Exception:
295
+ vals = None
296
+ if not vals:
297
+ continue
298
+ for aid, v in vals.items():
299
+ if v is None:
300
+ continue
301
+ lst = agg[mname].setdefault(str(aid), [])
302
+ try:
303
+ lst.append(float(v))
304
+ except Exception:
305
+ continue
306
+
307
+ values = finalize_rollout(agg)
308
+ rec = {
309
+ "mgid": getattr(root, "id", None),
310
+ "crn_id": getattr(root, "crn_id", None),
311
+ "iteration": iteration_name,
312
+ "stats": values,
313
+ }
314
+ w.write(json.dumps(rec, ensure_ascii=False) + "\n")
315
+
316
+ del root
317
+ gc.collect()
318
+ else:
319
+ records: List[Dict[str, Any]] = []
320
+ for iteration_folder in iteration_folders:
321
+ iteration_name = Path(iteration_folder).name
322
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
323
+ root = load_root(pkl_path)
324
+
325
+ agg: Dict[str, Dict[str, List[float]]] = {m: {} for m in metrics.keys()}
326
+ for sl in iterate_main_simulation_logs(root):
327
+ for mname, fn in metrics.items():
328
+ try:
329
+ vals = fn(sl)
330
+ except Exception:
331
+ vals = None
332
+ if not vals:
333
+ continue
334
+ for aid, v in vals.items():
335
+ if v is None:
336
+ continue
337
+ lst = agg[mname].setdefault(str(aid), [])
338
+ try:
339
+ lst.append(float(v))
340
+ except Exception:
341
+ continue
342
+
343
+ values = finalize_rollout(agg)
344
+ records.append(
345
+ {
346
+ "mgid": getattr(root, "id", None),
347
+ "crn_id": getattr(root, "crn_id", None),
348
+ "iteration": iteration_name,
349
+ "stats": values,
350
+ }
351
+ )
352
+
353
+ del root
354
+ gc.collect()
355
+
356
+ # Build dict-of-lists output
357
+ mgids: List[Any] = []
358
+ crn_ids: List[Any] = []
359
+ iterations_out: List[str] = []
360
+ stats_out: Dict[str, Any] = {}
361
+
362
+ stat_keys: set[str] = set()
363
+ nested_agent_keys: Dict[str, set[str]] = {}
364
+ for r in records:
365
+ stats = r.get("stats", {}) or {}
366
+ for k, v in stats.items():
367
+ stat_keys.add(k)
368
+ if isinstance(v, dict):
369
+ nested = nested_agent_keys.setdefault(k, set())
370
+ for ak in v.keys():
371
+ nested.add(str(ak))
372
+
373
+ for k in stat_keys:
374
+ if k in nested_agent_keys:
375
+ stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
376
+ else:
377
+ stats_out[k] = []
378
+
379
+ for r in records:
380
+ mgids.append(r.get("mgid"))
381
+ crn_ids.append(r.get("crn_id"))
382
+ iterations_out.append(r.get("iteration"))
383
+ stats = r.get("stats", {}) or {}
384
+ for k in stat_keys:
385
+ val = stats.get(k)
386
+ if isinstance(stats_out[k], dict):
387
+ agent_dict = val if isinstance(val, dict) else {}
388
+ for ak in stats_out[k].keys():
389
+ stats_out[k][ak].append(agent_dict.get(ak))
390
+ else:
391
+ stats_out[k].append(val)
392
+
393
+ with open(outfile, "w", encoding="utf-8") as w:
394
+ json.dump(
395
+ {
396
+ "mgid": mgids,
397
+ "crn_id": crn_ids,
398
+ "iteration": iterations_out,
399
+ "stats": stats_out,
400
+ },
401
+ w,
402
+ ensure_ascii=False,
403
+ )
404
+
405
+ return outfile
src_code_for_reproducibility/models/__init__.py ADDED
File without changes
src_code_for_reproducibility/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (192 Bytes). View file
 
src_code_for_reproducibility/models/__pycache__/adapter_training_wrapper.cpython-311.pyc ADDED
Binary file (4.66 kB). View file
 
src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-311.pyc ADDED
Binary file (2.46 kB). View file
 
src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-311.pyc ADDED
Binary file (2.82 kB). View file
 
src_code_for_reproducibility/models/__pycache__/inference_backend_sglang.cpython-311.pyc ADDED
Binary file (4.13 kB). View file
 
src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-311.pyc ADDED
Binary file (5.19 kB). View file
 
src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-311.pyc ADDED
Binary file (7.43 kB). View file
 
src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-311.pyc ADDED
Binary file (3.43 kB). View file
 
src_code_for_reproducibility/models/adapter_training_wrapper.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
+ import logging
4
+ from typing import Union
5
+ from peft import (
6
+ LoraConfig,
7
+ get_peft_model,
8
+ )
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class AdapterWrapper(nn.Module):
    """
    Thin wrapper around a *shared* PEFT-wrapped model for one named adapter.

    The wrapper
      • keeps a reference to the shared model,
      • calls `set_adapter(adapter_id)` before every forward pass and before
        listing parameters,
      • exposes, through `parameters()`, only the parameters of the shared
        model that currently have `requires_grad` set.
    """

    def __init__(
        self,
        shared_llm: nn.Module,
        adapter_id: str,
        lora_config: dict,
        path: Union[str, None] = None,
    ):
        """
        Attach a LoRA adapter named ``adapter_id`` to ``shared_llm``.

        Args:
            shared_llm: Model to wrap with ``get_peft_model``.
            adapter_id: Name under which the adapter is registered.
            lora_config: Keyword arguments for ``LoraConfig``.
            path: Optional location of existing adapter weights. A failed load
                is logged as a warning and the adapter keeps fresh weights.
        """
        super().__init__()
        self.shared_llm = shared_llm
        self.adapter_id = adapter_id
        peft_config = LoraConfig(**lora_config)
        # this modifies the shared llm in place, adding a lora adapter inside
        self.shared_llm = get_peft_model(
            model=shared_llm,
            peft_config=peft_config,
            adapter_name=adapter_id,
        )
        self.shared_llm.train()

        # Try to pull in external adapter weights when a location was given.
        loaded_from: str | None = None
        if path:
            try:
                # Supports both local filesystem paths and HF Hub repo IDs
                self.shared_llm.load_adapter(
                    model_id=path,
                    adapter_name=adapter_id,
                    is_trainable=True,
                )
            except Exception as exc:  # noqa: BLE001 - want to log any load failure context
                logger.warning(
                    f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
                )
            else:
                loaded_from = path

        if not loaded_from:
            logger.info(
                f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)."
            )
        else:
            logger.info(
                f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'."
            )

    def parameters(self, recurse: bool = True):
        """
        Return a list of the shared model's parameters that require grad.

        The adapter is activated first. ``recurse`` is accepted only for
        signature compatibility with PyTorch and is not used.
        """
        self.shared_llm.set_adapter(self.adapter_id)
        return list(
            filter(lambda param: param.requires_grad, self.shared_llm.parameters())
        )

    def forward(self, *args, **kwargs):
        """Activate this wrapper's adapter, then call the shared model."""
        self.shared_llm.set_adapter(self.adapter_id)
        return self.shared_llm(*args, **kwargs)

    def save_pretrained(self, save_path):
        """Delegate to the shared model's ``save_pretrained``."""
        self.shared_llm.save_pretrained(save_path)

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Delegate to the shared model's ``gradient_checkpointing_enable``."""
        self.shared_llm.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        """``dtype`` of the shared model."""
        return self.shared_llm.dtype

    @property
    def device(self):
        """``device`` of the shared model."""
        return self.shared_llm.device
src_code_for_reproducibility/models/inference_backend.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Optional
3
+ from dataclasses import dataclass
4
+
5
+ @dataclass
6
+ class PolicyOutput:
7
+ content: str
8
+ reasoning_content: str | None = None
9
+
10
class LLMInferenceBackend(ABC):
    """
    Abstract interface shared by the inference backends.

    Every method is abstract, including ``__init__``; a concrete backend must
    override all of them before it can be instantiated.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        """Set up the backend from keyword arguments (no-op here)."""

    @abstractmethod
    def prepare_adapter(
        self, adapter_id: str, weights_got_updated: bool = False
    ) -> None:
        """Ensure adapter is ready/loaded for next generation call."""

    @abstractmethod
    async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> PolicyOutput:
        """Generate a ``PolicyOutput`` for ``prompt``, optionally with ``regex``."""

    @abstractmethod
    def toggle_training_mode(self) -> None:
        """Switch the backend to its training mode."""

    @abstractmethod
    def toggle_eval_mode(self) -> None:
        """Switch the backend to its evaluation mode."""

    @abstractmethod
    def shutdown(self) -> None:
        """Shut the backend down."""
src_code_for_reproducibility/models/inference_backend_dummy.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import re
3
+ from typing import Optional
4
+
5
+ import rstr
6
+ from transformers import AutoTokenizer
7
+
8
+ from mllm.models.inference_backend import LLMInferenceBackend, PolicyOutput
9
+ from mllm.utils.short_id_gen import generate_short_id
10
+
11
+
12
class DummyInferenceBackend(LLMInferenceBackend):
    """
    Backend that produces placeholder text instead of running a model.

    With a ``regex`` it returns a random string matching that pattern (via
    ``rstr.xeger``); without one it returns a fixed sentence. All other
    interface methods do nothing.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Accept and ignore any arguments."""

    def prepare_adapter(
        self, adapter_id: Optional[str], weights_got_updated: bool = False
    ) -> None:
        """No-op. ``weights_got_updated`` defaults to False like the base interface."""

    async def toggle_training_mode(self) -> None:
        """No-op apart from yielding control to the event loop once."""
        await asyncio.sleep(0)

    async def toggle_eval_mode(self) -> None:
        """No-op apart from yielding control to the event loop once."""
        await asyncio.sleep(0)

    def shutdown(self) -> None:
        """No-op."""

    async def generate(
        self, prompt_text: str, regex: Optional[str] = None
    ) -> PolicyOutput:
        """
        Return placeholder output; ``prompt_text`` is not used.

        When ``regex`` is given, a random matching string is generated. If that
        string has the shape ``\\n<think>\\n...</think>\\n\\n<rest>``, the part
        inside the think tags becomes ``reasoning_content`` and ``<rest>``
        becomes ``content``; otherwise the whole string is the content.
        """
        content = "I am a dummy backend without a regex."
        reasoning_content = None

        if regex:
            raw_text = rstr.xeger(regex)
            content = raw_text
            # Strict split: require \n<think>...</think>\n\n before final content
            m = re.match(
                r"^\n<think>\n([\s\S]*?)</think>\n\n(.*)$", raw_text, flags=re.DOTALL
            )
            if m:
                reasoning_content = m.group(1)
                content = m.group(2)

        return PolicyOutput(content=content, reasoning_content=reasoning_content)
src_code_for_reproducibility/models/inference_backend_sglang.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # new_backend_sglang_offline.py
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ from typing import Any, Optional
6
+
7
+ import sglang as sgl
8
+
9
+ from mllm.models.inference_backend import LLMInferenceBackend
10
+
11
+
12
class SGLangOfflineBackend(LLMInferenceBackend):
    """
    Backend that runs an in-process ``sgl.Engine`` (no HTTP server).

    NOTE(review): ``generate`` here takes ``(prompt_text, sampling_params,
    adapter_id)`` and returns a plain string, which differs from the
    ``generate(prompt, regex) -> PolicyOutput`` signature declared on
    ``LLMInferenceBackend`` — confirm which contract callers rely on.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer,  # unused but kept for parity
        adapter_paths: dict[str, str],
        device: str = "cuda",
        max_model_len: Optional[int] = None,
        enable_lora: bool = True,
        lora_target_modules: Optional[list[str] | str] = None,
        max_loras_per_batch: int = 8,
        engine_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Build the engine keyword arguments and start the engine.

        Values already present in ``engine_kwargs`` win over the values derived
        from the other parameters (``setdefault`` is used throughout). The
        caller's ``engine_kwargs`` dict is copied, not modified. ``tokenizer``
        and ``device`` are accepted but not used.
        """
        self.model_name = model_name
        self.adapter_paths = adapter_paths
        self.current_adapter: Optional[str] = None
        engine_kwargs = dict(engine_kwargs or {})
        # Map server-style LoRA flags to offline engine ctor
        if enable_lora and adapter_paths:
            engine_kwargs.setdefault("enable_lora", True)
            # The offline Engine mirrors server args; pass a mapping name->path
            engine_kwargs.setdefault("lora_paths", adapter_paths)
            if lora_target_modules is not None:
                engine_kwargs.setdefault("lora_target_modules", lora_target_modules)
            engine_kwargs.setdefault("max_loras_per_batch", max_loras_per_batch)

        if max_model_len is not None:
            engine_kwargs.setdefault("context_length", max_model_len)

        # Launch in-process engine (no HTTP server)
        self.llm = sgl.Engine(model_path=model_name, **engine_kwargs)

    def is_ready(self) -> bool:
        """Always True: the engine is created synchronously in ``__init__``."""
        return True

    def toggle_training_mode(self) -> None:
        """No-op."""
        # No explicit KV release API offline; typically you pause usage here.

    def toggle_eval_mode(self) -> None:
        """No-op."""

    def shutdown(self) -> None:
        """No-op."""
        # Engine cleans up on GC; explicit close not required.

    def prepare_adapter(
        self, adapter_id: Optional[str], weights_got_updated: bool = False
    ) -> None:
        """
        Remember ``adapter_id`` as the current adapter.

        ``weights_got_updated`` is accepted for compatibility with the base
        interface and is not used here.
        """
        # With offline Engine, when LoRA is enabled at init,
        # you select adapter per request via the input batch mapping.
        self.current_adapter = adapter_id

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: Optional[str]
    ) -> str:
        """
        Generate text for a single prompt and return it.

        Args:
            prompt_text: Prompt string.
            sampling_params: Source of ``temperature``, ``top_p``,
                ``max_new_tokens`` (defaults 1.0, 1.0, 128) and of the optional
                ``top_k`` (only if > 0), ``min_new_tokens``, ``frequency_penalty``.
            adapter_id: LoRA adapter name for this request; falsy means no
                adapter is requested.
        """
        # Non-streaming async (batch of 1). For batched prompts, pass a list.
        params = {
            "temperature": sampling_params.get("temperature", 1.0),
            "top_p": sampling_params.get("top_p", 1.0),
            "max_new_tokens": sampling_params.get("max_new_tokens", 128),
        }
        if (tk := sampling_params.get("top_k", -1)) and tk > 0:
            params["top_k"] = tk
        if (mn := sampling_params.get("min_new_tokens")) is not None:
            params["min_new_tokens"] = mn
        if (fp := sampling_params.get("frequency_penalty")) is not None:
            params["frequency_penalty"] = fp

        # If using multi-LoRA, SGLang lets you provide adapter names aligned to each input.
        prompts = [prompt_text]
        adapters = [adapter_id] if adapter_id else None  # or omit for base
        # Keyword arguments on purpose: the third positional parameter of
        # Engine.async_generate is ``input_ids``, not the LoRA selection.
        outs = await self.llm.async_generate(
            prompt=prompts,
            sampling_params=params,
            lora_path=adapters,
        )
        return outs[0]["text"]
src_code_for_reproducibility/models/inference_backend_sglang_local_server.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import httpx
4
+ import requests
5
+ from sglang.utils import launch_server_cmd, wait_for_server
6
+
7
+ from mllm.models.inference_backend import LLMInferenceBackend
8
+
9
+
10
class HttpSGLangBackend(LLMInferenceBackend):
    """
    Backend that talks to a locally launched SGLang HTTP server.

    NOTE(review): ``generate`` here takes ``(prompt_text, sampling_params,
    adapter_id)`` and returns a plain string, which differs from the
    ``generate(prompt, regex) -> PolicyOutput`` signature declared on
    ``LLMInferenceBackend`` — confirm which contract callers rely on.
    """

    def __init__(self, **kwargs):
        """
        Read configuration from keyword arguments.

        Recognised keys: ``model_name``, ``adapter_paths`` (adapter id -> path),
        ``mem_fraction_static``, ``dtype``, ``extra_cli``, ``disable_radix_cache``.
        The server is not started here; call ``launch``.
        """
        super().__init__(**kwargs)
        # The base __init__ stores nothing, so the attributes used below and in
        # launch() must be set here.
        self.model_name = kwargs.get("model_name")
        self.adapter_paths = dict(kwargs.get("adapter_paths") or {})
        self.port = None
        self.proc = None
        self.urls = {}
        # track sglang adapter ids separately from your logical ids
        self.sglang_names = {aid: aid for aid in self.adapter_paths.keys()}
        self.needs_loading = {aid: True for aid in self.adapter_paths.keys()}

        # defaults you already used:
        self.mem_fraction = kwargs.get("mem_fraction_static", 0.6)
        self.dtype = kwargs.get("dtype", "bfloat16")
        self.extra_cli = kwargs.get("extra_cli", "")
        self.disable_radix_cache = kwargs.get("disable_radix_cache", True)

    def launch(self) -> None:
        """Start the SGLang server process, wait for it, and record its URLs."""
        # find local hf cache path for server
        from transformers.utils import cached_file

        local_llm_path = os.path.split(cached_file(self.model_name, "config.json"))[0]

        lora_str = ""
        if self.adapter_paths:
            lora_str = "--lora-paths " + " ".join(
                f"{aid}={path}" for aid, path in self.adapter_paths.items()
            )

        # NOTE(review): the server binds to 0.0.0.0 (all interfaces) — confirm
        # this is intended for a "local" server.
        cmd = f"""
        python3 -m sglang.launch_server --model-path {local_llm_path} \
        --host 0.0.0.0 {lora_str} \
        {'--disable-radix-cache' if self.disable_radix_cache else ''} \
        --mem-fraction-static {self.mem_fraction} --dtype {self.dtype} {self.extra_cli}
        """
        self.proc, self.port = launch_server_cmd(cmd)
        wait_for_server(f"http://localhost:{self.port}")
        base = f"http://localhost:{self.port}"
        self.urls = dict(
            generate=f"{base}/generate",
            release=f"{base}/release_memory_occupation",
            resume=f"{base}/resume_memory_occupation",
            load_lora=f"{base}/load_lora_adapter",
            unload_lora=f"{base}/unload_lora_adapter",
        )

    def is_ready(self) -> bool:
        """Return True if a GET on the generate URL completes without raising."""
        try:
            requests.get(self.urls["generate"], timeout=2)
            return True
        except Exception:
            # Also covers the not-yet-launched case where ``urls`` is empty.
            return False

    def prepare_adapter(
        self, adapter_id: str, weights_got_updated: bool = False
    ) -> None:
        """
        Make sure the server has the current weights of ``adapter_id`` loaded.

        An adapter is (re)loaded on first use and whenever
        ``weights_got_updated`` is True. Reloading unloads the previously
        registered name (best effort) and loads the adapter path again under a
        fresh name. ``adapter_id=None`` does nothing.
        """
        if adapter_id is None:
            return
        if weights_got_updated:
            self.needs_loading[adapter_id] = True
        if self.needs_loading.get(adapter_id, False):
            # unload old name if present
            try:
                requests.post(
                    self.urls["unload_lora"],
                    json={"lora_name": self.sglang_names[adapter_id]},
                    timeout=10,
                )
            except requests.RequestException:
                # Best effort: a failed unload must not block the reload.
                pass
            new_name = self._short_id()
            self.sglang_names[adapter_id] = new_name
            requests.post(
                self.urls["load_lora"],
                json={
                    "lora_name": new_name,
                    "lora_path": self.adapter_paths[adapter_id],
                },
            ).raise_for_status()
            self.needs_loading[adapter_id] = False

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: str | None
    ) -> str:
        """
        POST one prompt to the server's generate URL and return the text.

        When ``adapter_id`` maps to a registered server-side adapter name, that
        name is sent as ``lora_path``. Raises for non-success HTTP statuses.
        """
        lora_name = self.sglang_names.get(adapter_id) if adapter_id else None
        payload = {
            "text": [prompt_text],
            "sampling_params": sampling_params,
        }
        if lora_name:
            payload["lora_path"] = [lora_name]

        # Generous one-hour limits: a single generation may take very long.
        timeout = httpx.Timeout(3600.0, connect=3600.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(self.urls["generate"], json=payload)
            resp.raise_for_status()
            return resp.json()[0]["text"]

    def toggle_training_mode(self) -> None:
        """Ask the server to release its KV cache memory; raises on HTTP error."""
        # free KV space while training adapters
        requests.post(
            self.urls["release"], json={"tags": ["kv_cache"]}
        ).raise_for_status()

    def toggle_eval_mode(self) -> None:
        """Ask the server to re-allocate its KV cache memory (best effort)."""
        # re-allocate KV space
        try:
            requests.post(
                self.urls["resume"], json={"tags": ["kv_cache"]}
            ).raise_for_status()
        except Exception:
            # Deliberately ignored, as in the original behaviour.
            pass

    def shutdown(self) -> None:
        """Terminate the server process if one was started."""
        from sglang.utils import terminate_process

        if self.proc:
            terminate_process(self.proc)

    def _short_id(self) -> str:
        """Return the first 8 decimal digits of a random UUID's integer value."""
        import uuid

        return str(uuid.uuid4().int)[:8]
src_code_for_reproducibility/models/inference_backend_vllm.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import re
3
+ from typing import Optional
4
+
5
+ from transformers import AutoTokenizer
6
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
7
+ from vllm.lora.request import LoRARequest
8
+ from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind
9
+
10
+ from mllm.models.inference_backend import LLMInferenceBackend, PolicyOutput
11
+ from mllm.utils.short_id_gen import generate_short_id
12
+
13
+
14
class VLLMAsyncBackend(LLMInferenceBackend):
    """Backend built on vLLM's ``AsyncLLMEngine`` with LoRA and sleep mode enabled."""

    def __init__(
        self,
        model_name: str,
        tokenizer: AutoTokenizer,
        adapter_paths: dict[str, str],
        engine_init_kwargs: Optional[dict] = None,
        sampling_params: Optional[dict] = None,
    ):
        """
        Create the engine.

        Args:
            model_name: Passed to the engine as ``model``.
            tokenizer: Accepted but not used here.
            adapter_paths: Adapter id -> adapter path; ``None`` means none.
            engine_init_kwargs: Extra ``AsyncEngineArgs`` keyword arguments.
                ``enable_lora``, ``max_loras`` and ``enable_sleep_mode`` are
                always overwritten.
            sampling_params: Keyword arguments for every ``SamplingParams``.
        """
        self.model_name = model_name
        self.adapter_paths = adapter_paths or {}
        self.current_adapter = None
        # No adapter selected until prepare_adapter() is called.
        self.current_lora_request = None
        self.vllm_adapter_ids = {
            adapter_id: generate_short_id() for adapter_id in self.adapter_paths
        }
        ea = dict(model=model_name, **(engine_init_kwargs or {}))
        ea["enable_lora"] = True
        # presumably vLLM expects at least one LoRA slot when enable_lora is
        # set — TODO confirm; never request zero.
        ea["max_loras"] = max(1, len(self.vllm_adapter_ids))
        ea["enable_sleep_mode"] = True
        self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))

        self.sampling_params = sampling_params if sampling_params is not None else {}

    def prepare_adapter(
        self, adapter_id: Optional[str], weights_got_updated: bool = False
    ) -> None:
        """
        Select the adapter used by subsequent ``generate`` calls.

        ``adapter_id=None`` clears the selection (requests carry no LoRA).
        When ``weights_got_updated`` is True the adapter gets a new vLLM id.
        Raises ``KeyError`` for an adapter id that has no configured path.
        """
        self.current_adapter = adapter_id
        if adapter_id is None:
            self.current_lora_request = None
            return
        # Look the path up first so an unknown id fails before any state change.
        adapter_path = self.adapter_paths[adapter_id]
        if weights_got_updated:
            self.vllm_adapter_ids[adapter_id] = generate_short_id()
        self.current_lora_request = LoRARequest(
            adapter_id,
            self.vllm_adapter_ids[adapter_id],
            adapter_path,
        )

    async def toggle_training_mode(self) -> None:
        """Put the engine to sleep (level 1)."""
        await self.engine.sleep(level=1)

    async def toggle_eval_mode(self) -> None:
        """Wake the engine up."""
        await self.engine.wake_up()

    def shutdown(self) -> None:
        """No-op."""
        # No explicit close call; engine stops when process exits.

    async def generate(
        self, prompt_text: str, regex: Optional[str] = None
    ) -> PolicyOutput:
        """
        Generate a completion for ``prompt_text`` with the selected adapter.

        With ``regex``, decoding is guided by it, and output of the shape
        ``\\n<think>\\n...</think>\\n\\n<rest>`` is split into
        ``reasoning_content`` and ``content``. Otherwise the raw text is the
        content. Raises ``RuntimeError`` if the engine yields no output.
        """
        import uuid

        guided = GuidedDecodingParams(regex=regex) if regex else None
        sp = SamplingParams(
            **self.sampling_params,
            guided_decoding=guided,
            output_kind=RequestOutputKind.FINAL_ONLY,
        )

        # Random id: a clock-based id can repeat for requests issued together.
        request_id = f"req-{uuid.uuid4().hex}"
        result_generator = self.engine.generate(
            prompt_text,
            sp,  # SamplingParams(...)
            request_id,
            lora_request=self.current_lora_request,
        )

        res = None
        async for out in result_generator:  # with FINAL_ONLY this runs once
            res = out
        if res is None:
            raise RuntimeError(f"vLLM returned no output for request {request_id}")

        raw_text = res.outputs[0].text

        content = raw_text
        reasoning_content = None

        if regex:
            # Strict split: require \n<think>...</think>\n\n before final content
            m = re.match(
                r"^\n<think>\n([\s\S]*?)</think>\n\n(.*)$", raw_text, flags=re.DOTALL
            )
            if m:
                reasoning_content = m.group(1)
                content = m.group(2)

        return PolicyOutput(content=content, reasoning_content=reasoning_content)
src_code_for_reproducibility/models/inference_backend_vllm_local_server.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import subprocess
4
+ import time
5
+
6
+ import httpx
7
+ import requests
8
+
9
+ from mllm.models.inference_backend import LLMInferenceBackend
10
+
11
+
12
class HttpVLLMBackend(LLMInferenceBackend):
    """Inference backend that runs a vLLM OpenAI-compatible server as a child
    process and talks to it over HTTP.

    Recognised keyword arguments (all optional, all also forwarded to the
    base class): ``port``, ``host``, ``gpu_memory_utilization``,
    ``max_model_len``, ``max_num_seqs``, ``max_num_batched_tokens``,
    ``dtype``, ``trust_remote_code`` and ``lora_mode`` (``"preload"`` or
    ``"runtime"``).

    ``self.model_name`` and ``self.adapter_paths`` are read but never assigned
    here; they are assumed to be set by the base-class constructor.
    """

    # Wildcard bind addresses; valid for listening, not portable to connect to.
    _WILDCARD_HOSTS = ("0.0.0.0", "::")

    def __init__(self, **kwargs):
        """Store server settings and pre-compute the LoRA command-line flags."""
        super().__init__(**kwargs)
        self.port = kwargs.get("port", 8000)
        self.host = kwargs.get("host", "0.0.0.0")
        self.proc = None
        # The server binds to self.host, but clients must use a concrete
        # address: connecting to a wildcard address is not portable.
        connect_host = (
            "127.0.0.1" if self.host in self._WILDCARD_HOSTS else self.host
        )
        self.base_url = f"http://{connect_host}:{self.port}"
        # vLLM memory safety knobs
        self.gpu_mem_util = kwargs.get("gpu_memory_utilization", 0.9)
        self.max_model_len = kwargs.get("max_model_len", None)
        self.max_num_seqs = kwargs.get("max_num_seqs", None)
        self.max_batched_tokens = kwargs.get("max_num_batched_tokens", None)
        self.dtype = kwargs.get("dtype", "bfloat16")
        self.trust_remote_code = kwargs.get("trust_remote_code", False)
        # LoRA strategy: "preload" (CLI flags) or "runtime" (HTTP endpoints)
        self.lora_mode = kwargs.get("lora_mode", "preload")
        self.runtime_lora_enabled = self.lora_mode == "runtime"

        # If preloading: one "--lora-modules name=path" pair per adapter.
        self._preload_lora_args = []
        if self.adapter_paths and self.lora_mode == "preload":
            for aid, pth in self.adapter_paths.items():
                self._preload_lora_args += ["--lora-modules", f"{aid}={pth}"]

    def launch(self):
        """Start the vLLM server process and block until it answers.

        Raises:
            RuntimeError: If the server exits early or is not ready in time.
        """
        import sys

        cmd = [
            # Same interpreter as this process, so the server runs in the
            # same environment; fall back to the PATH lookup if unknown.
            sys.executable or "python3",
            "-m",
            "vllm.entrypoints.openai.api_server",
            "--model",
            self.model_name,
            "--host",
            self.host,
            "--port",
            str(self.port),
            "--dtype",
            self.dtype,
            "--gpu-memory-utilization",
            str(self.gpu_mem_util),
        ]
        if self.trust_remote_code:
            cmd += ["--trust-remote-code"]
        if self.max_model_len:
            cmd += ["--max-model-len", str(self.max_model_len)]
        if self.max_num_seqs:
            cmd += ["--max-num-seqs", str(self.max_num_seqs)]
        if self.max_batched_tokens:
            cmd += ["--max-num-batched-tokens", str(self.max_batched_tokens)]
        if self.adapter_paths:
            # LoRA support has to be switched on for adapters to be served,
            # whether they are preloaded or loaded at runtime.
            cmd += ["--enable-lora"]
        cmd += self._preload_lora_args

        env = dict(os.environ)
        if self.runtime_lora_enabled:
            # vLLM only exposes the runtime load/unload endpoints when this
            # variable is set; respect a value the caller already chose.
            env.setdefault("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "True")

        # Inherit the parent's stdout instead of using an undrained PIPE: a
        # pipe nobody reads fills up and blocks the server on its next write.
        self.proc = subprocess.Popen(cmd, stderr=subprocess.STDOUT, env=env)
        self._wait_ready()

    def _wait_ready(self, timeout=120):
        """Poll the server once per second until it is ready.

        Args:
            timeout: Maximum number of seconds to wait.

        Raises:
            RuntimeError: If the child process exits first, or if the timeout
                elapses without a successful readiness check.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # Fail fast instead of waiting out the timeout on a dead server.
            if self.proc is not None and self.proc.poll() is not None:
                raise RuntimeError(
                    "vLLM server exited with code "
                    f"{self.proc.returncode} before becoming ready"
                )
            if self.is_ready():
                return
            time.sleep(1)
        raise RuntimeError("vLLM server did not become ready in time")

    def is_ready(self) -> bool:
        """Return True when ``GET /v1/models`` answers with HTTP 200."""
        try:
            response = requests.get(f"{self.base_url}/v1/models", timeout=2)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def prepare_adapter(self, adapter_id: str) -> None:
        """Ask the server to load a LoRA adapter (runtime mode only).

        Best effort: failures are logged, never raised, because the adapter
        may already be loaded or the endpoint may be unavailable. Does nothing
        when ``adapter_id`` is empty or runtime loading is disabled.
        """
        if not adapter_id or not self.runtime_lora_enabled:
            return

        import logging

        log = logging.getLogger(__name__)
        try:
            adapter_path = self.adapter_paths[adapter_id]
        except KeyError:
            log.warning("No path configured for LoRA adapter %r", adapter_id)
            return

        try:
            requests.post(
                f"{self.base_url}/v1/load_lora_adapter",
                # Field names follow vLLM's dynamic LoRA loading endpoint.
                json={"lora_name": adapter_id, "lora_path": adapter_path},
                timeout=10,
            ).raise_for_status()
        except requests.RequestException as exc:
            log.warning("Loading LoRA adapter %r failed: %s", adapter_id, exc)

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: str | None
    ) -> str:
        """Request one chat completion and return the message content.

        Args:
            prompt_text: Sent as a single ``user`` message.
            sampling_params: Optional keys ``temperature``, ``top_p``,
                ``max_new_tokens``, ``top_k``, ``min_new_tokens`` and
                ``frequency_penalty``; missing keys use the defaults below.
            adapter_id: Name of the LoRA adapter to use, or ``None`` / empty
                for the base model.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        body = {
            # vLLM serves each LoRA adapter under its own model name, so the
            # adapter is selected through the "model" field.
            "model": adapter_id if adapter_id else self.model_name,
            "messages": [{"role": "user", "content": prompt_text}],
            "temperature": sampling_params.get("temperature", 1.0),
            "top_p": sampling_params.get("top_p", 1.0),
            "max_tokens": sampling_params.get("max_new_tokens", 128),
        }
        # vLLM-specific fields go at the top level of the JSON payload;
        # "extra_body" exists only as an argument of the OpenAI Python client.
        top_k = sampling_params.get("top_k")
        if top_k is not None and top_k > 0:
            body["top_k"] = top_k
        min_new_tokens = sampling_params.get("min_new_tokens")
        if min_new_tokens is not None:
            body["min_tokens"] = min_new_tokens
        frequency_penalty = sampling_params.get("frequency_penalty")
        if frequency_penalty is not None:
            body["frequency_penalty"] = frequency_penalty

        url = f"{self.base_url}/v1/chat/completions"
        timeout = httpx.Timeout(3600.0, connect=3600.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, json=body)
            resp.raise_for_status()
            data = resp.json()
            return data["choices"][0]["message"]["content"]

    def toggle_training_mode(self) -> None:
        """No-op: this backend has no API call to release server memory.

        Strategy: keep the inference server idle during training, or run
        training in a separate process.
        """

    def toggle_eval_mode(self) -> None:
        """No-op counterpart of ``toggle_training_mode``."""

    def shutdown(self) -> None:
        """Stop the server process, if any, and wait for it to exit."""
        proc, self.proc = self.proc, None
        if proc is None or proc.poll() is not None:
            return
        proc.terminate()
        try:
            # Reap the child so it does not linger as a zombie.
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()