Muqeeth commited on
Commit
5fb294e
·
verified ·
1 Parent(s): f41f3d1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. run.log +0 -0
  2. src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
  3. src_code_for_reproducibility/chat_utils/apply_template.py +12 -1
  4. src_code_for_reproducibility/chat_utils/chat_turn.py +5 -0
  5. src_code_for_reproducibility/chat_utils/template_specific.py +27 -0
  6. src_code_for_reproducibility/markov_games/__init__.py +4 -0
  7. src_code_for_reproducibility/markov_games/agent.py +18 -22
  8. src_code_for_reproducibility/markov_games/alternative_actions_runner.py +19 -11
  9. src_code_for_reproducibility/markov_games/group_timesteps.py +3 -20
  10. src_code_for_reproducibility/markov_games/linear_runner.py +13 -1
  11. src_code_for_reproducibility/markov_games/markov_game.py +35 -26
  12. src_code_for_reproducibility/markov_games/mg_utils.py +16 -8
  13. src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py +34 -11
  14. src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py +14 -3
  15. src_code_for_reproducibility/markov_games/negotiation/tas_agent.py +10 -0
  16. src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py +10 -0
  17. src_code_for_reproducibility/markov_games/rollout_tree.py +10 -1
  18. src_code_for_reproducibility/markov_games/run_markov_games.py +11 -0
  19. src_code_for_reproducibility/markov_games/simulation.py +25 -18
  20. src_code_for_reproducibility/markov_games/statistics_runner.py +10 -0
  21. src_code_for_reproducibility/models/__init__.py +4 -0
  22. src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc +0 -0
  23. src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc +0 -0
  24. src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc +0 -0
  25. src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc +0 -0
  26. src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc +0 -0
  27. src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc +0 -0
  28. src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc +0 -0
  29. src_code_for_reproducibility/models/adapter_training_wrapper.py +14 -8
  30. src_code_for_reproducibility/models/human_policy.py +5 -0
  31. src_code_for_reproducibility/models/inference_backend.py +5 -0
  32. src_code_for_reproducibility/models/inference_backend_dummy.py +5 -0
  33. src_code_for_reproducibility/models/inference_backend_vllm.py +6 -12
  34. src_code_for_reproducibility/models/large_language_model_api.py +7 -4
  35. src_code_for_reproducibility/models/large_language_model_local.py +8 -31
  36. src_code_for_reproducibility/models/scalar_critic.py +14 -9
  37. src_code_for_reproducibility/training/__init__.py +4 -0
  38. src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc +0 -0
  39. src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc +0 -0
  40. src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc +0 -0
  41. src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc +0 -0
  42. src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc +0 -0
  43. src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc +0 -0
  44. src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc +0 -0
  45. src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc +0 -0
  46. src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc +0 -0
  47. src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc +0 -0
  48. src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc +0 -0
  49. src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc +0 -0
  50. src_code_for_reproducibility/training/annealing_methods.py +15 -1
run.log CHANGED
The diff for this file is too large to render. See raw diff
 
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc and b/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc differ
 
src_code_for_reproducibility/chat_utils/apply_template.py CHANGED
@@ -1,10 +1,17 @@
 
 
 
 
 
1
  import torch
2
 
3
  from mllm.chat_utils.chat_turn import ChatTurn
4
  from mllm.chat_utils.template_specific import (
 
5
  custom_llama3_template,
6
  custom_qwen2_template,
7
  custom_qwen3_template,
 
8
  qwen2_assistant_postfix,
9
  qwen3_assistant_postfix,
10
  )
@@ -20,6 +27,8 @@ def get_custom_chat_template(tokenizer) -> str:
20
  return custom_llama3_template
21
  elif "qwen3" in tokenizer.name_or_path.lower():
22
  return custom_qwen3_template
 
 
23
  else:
24
  raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
25
 
@@ -32,13 +41,15 @@ def get_custom_assistant_postfix(tokenizer) -> torch.Tensor:
32
  return qwen2_assistant_postfix
33
  elif "qwen3" in tokenizer.name_or_path.lower():
34
  return qwen3_assistant_postfix
 
 
35
  return torch.tensor([], dtype=torch.long)
36
 
37
 
38
  def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
39
  """
40
  Set the chat_template_token_ids for each chat turn.
41
- # TODO: use engine tokens if available
42
  """
43
  custom_template = get_custom_chat_template(tokenizer)
44
  custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
 
1
+ """
2
+ File: mllm/chat_utils/apply_template.py
3
+ Summary: Applies tokenizer-specific chat templates and stitches chat token IDs.
4
+ """
5
+
6
  import torch
7
 
8
  from mllm.chat_utils.chat_turn import ChatTurn
9
  from mllm.chat_utils.template_specific import (
10
+ custom_gemma3_template,
11
  custom_llama3_template,
12
  custom_qwen2_template,
13
  custom_qwen3_template,
14
+ gemma3_assistant_postfix,
15
  qwen2_assistant_postfix,
16
  qwen3_assistant_postfix,
17
  )
 
27
  return custom_llama3_template
28
  elif "qwen3" in tokenizer.name_or_path.lower():
29
  return custom_qwen3_template
30
+ elif "gemma" in tokenizer.name_or_path.lower():
31
+ return custom_gemma3_template
32
  else:
33
  raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
34
 
 
41
  return qwen2_assistant_postfix
42
  elif "qwen3" in tokenizer.name_or_path.lower():
43
  return qwen3_assistant_postfix
44
+ elif "gemma" in tokenizer.name_or_path.lower():
45
+ return gemma3_assistant_postfix
46
  return torch.tensor([], dtype=torch.long)
47
 
48
 
49
  def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
50
  """
51
  Set the chat_template_token_ids for each chat turn.
52
+ We rely on tokenizer-side templates because engine-provided cached tokens are not exposed yet.
53
  """
54
  custom_template = get_custom_chat_template(tokenizer)
55
  custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
src_code_for_reproducibility/chat_utils/chat_turn.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import json
 
1
+ """
2
+ File: mllm/chat_utils/chat_turn.py
3
+ Summary: Defines the ChatTurn schema plus helpers for serialization and validation.
4
+ """
5
+
6
  from __future__ import annotations
7
 
8
  import json
src_code_for_reproducibility/chat_utils/template_specific.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import huggingface_hub
2
  import torch
3
  from transformers import AutoTokenizer
@@ -25,6 +30,11 @@ qwen3_assistant_postfix = (
25
  .encode("\n", return_tensors="pt")
26
  .flatten()
27
  )
 
 
 
 
 
28
  custom_qwen2_template = """
29
  {%- if add_system_prompt %}
30
  {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
@@ -85,3 +95,20 @@ custom_qwen3_template = """
85
  {%- endif %}
86
  {%- endif %}
87
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: mllm/chat_utils/template_specific.py
3
+ Summary: Stores chat template variants and assistant postfix tensors per tokenizer.
4
+ """
5
+
6
  import huggingface_hub
7
  import torch
8
  from transformers import AutoTokenizer
 
30
  .encode("\n", return_tensors="pt")
31
  .flatten()
32
  )
33
+ gemma3_assistant_postfix = (
34
+ AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
35
+ .encode("\n", return_tensors="pt")
36
+ .flatten()
37
+ )
38
  custom_qwen2_template = """
39
  {%- if add_system_prompt %}
40
  {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
 
95
  {%- endif %}
96
  {%- endif %}
97
  """
98
+
99
+ custom_gemma3_template = """
100
+ {%- if add_system_prompt %}
101
+ {{- bos_token -}}
102
+ {%- endif %}
103
+ {%- for message in messages -%}
104
+ {%- if message['role'] == 'assistant' -%}
105
+ {%- set role = 'model' -%}
106
+ {%- else -%}
107
+ {%- set role = message['role'] -%}
108
+ {%- endif -%}
109
+ {{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
110
+ {%- endfor -%}
111
+ {%- if add_generation_prompt -%}
112
+ {{ '<start_of_turn>model\n' }}
113
+ {%- endif -%}
114
+ """
src_code_for_reproducibility/markov_games/__init__.py CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ File: mllm/markov_games/__init__.py
3
+ Summary: Makes Markov-game subpackages importable from the top-level namespace.
4
+ """
src_code_for_reproducibility/markov_games/agent.py CHANGED
@@ -1,11 +1,6 @@
1
  """
2
- In simple RL paradise, where the action dimensions are constant and well defined,
3
- Agent classes are not necessary. But in MARL, with LLM's, there isn't always
4
- a direct path from policy to action. For instance, from the observation of the environment,
5
- a prompt must be created. Then, the outputs of the policy might be incorrect, so a second
6
- request to the LLM must be sent before the action is well defined. This is why this Agent class exists.
7
- It acts as a mini environment, bridging the gap between the core simulation and
8
- the LLM policies.
9
  """
10
 
11
  from abc import ABC, abstractmethod
@@ -18,6 +13,8 @@ from mllm.markov_games.rollout_tree import AgentActLog
18
 
19
 
20
  class Agent(ABC):
 
 
21
  @abstractmethod
22
  def __init__(
23
  self,
@@ -29,7 +26,10 @@ class Agent(ABC):
29
  **kwargs,
30
  ):
31
  """
32
- Initialize the agent state.
 
 
 
33
  """
34
  self.seed = seed
35
  self.agent_id = agent_id
@@ -40,37 +40,33 @@ class Agent(ABC):
40
 
41
  async def act(self, observation) -> Tuple[Any, AgentActLog]:
42
  """
43
- Query (possibly multiple times) a policy (or possibly a pool of policies) to
44
- obtain the action of the agent.
45
 
46
- Example:
47
- action = None
48
- prompt = self.observation_to_prompt(observation)
49
- while not self.valid(action):
50
- output = await self.policy.generate(prompt)
51
- action = self.policy_output_to_action(output)
52
- return action
53
-
54
- Returns:
55
- action
56
- step_info
57
  """
58
  raise NotImplementedError
59
 
60
  def get_safe_copy(self):
61
  """
62
- Return copy of the agent object that is decorrelated from the original object.
 
 
63
  """
64
  raise NotImplementedError
65
 
66
  def reset(self):
 
67
  raise NotImplementedError
68
 
69
  def render(self):
 
70
  raise NotImplementedError
71
 
72
  def close(self):
 
73
  raise NotImplementedError
74
 
75
  def get_agent_info(self):
 
76
  raise NotImplementedError
 
1
  """
2
+ File: mllm/markov_games/agent.py
3
+ Summary: Declares the base Agent interface connecting simulations to policy calls.
 
 
 
 
 
4
  """
5
 
6
  from abc import ABC, abstractmethod
 
13
 
14
 
15
  class Agent(ABC):
16
+ """Abstract policy wrapper that bridges simulations with arbitrary backends."""
17
+
18
  @abstractmethod
19
  def __init__(
20
  self,
 
26
  **kwargs,
27
  ):
28
  """
29
+ Initialize the agent state and seed its RNG.
30
+
31
+ Subclasses typically store extra handles (tokenizers, inference clients, etc.)
32
+ but they should always call ``super().__init__`` so sampling remains reproducible.
33
  """
34
  self.seed = seed
35
  self.agent_id = agent_id
 
40
 
41
  async def act(self, observation) -> Tuple[Any, AgentActLog]:
42
  """
43
+ Produce the next action (and associated chat log) given an environment observation.
 
44
 
45
+ Implementations can iterate with rejection sampling, multi-call deliberation, etc.
46
+ Returns both the chosen action and an `AgentActLog` describing how it was produced.
 
 
 
 
 
 
 
 
 
47
  """
48
  raise NotImplementedError
49
 
50
  def get_safe_copy(self):
51
  """
52
+ Return a deep copy whose future calls do not mutate the original agent.
53
+
54
+ Needed for branch exploration/reruns with alternative actions.
55
  """
56
  raise NotImplementedError
57
 
58
  def reset(self):
59
+ """Reset any internal state between rollouts."""
60
  raise NotImplementedError
61
 
62
  def render(self):
63
+ """Optional human-readable visualization of the agent (CLI/UI)."""
64
  raise NotImplementedError
65
 
66
  def close(self):
67
+ """Release any external resources (network sockets, subprocesses, etc.)."""
68
  raise NotImplementedError
69
 
70
  def get_agent_info(self):
71
+ """Return diagnostic metadata to embed inside rollout logs."""
72
  raise NotImplementedError
src_code_for_reproducibility/markov_games/alternative_actions_runner.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import asyncio
2
  import copy
3
  import json
@@ -16,7 +21,6 @@ from mllm.markov_games.rollout_tree import (
16
  AgentId = str
17
 
18
 
19
-
20
  async def run_with_unilateral_alt_action(
21
  markov_game: MarkovGame,
22
  agent_id: AgentId,
@@ -25,7 +29,11 @@ async def run_with_unilateral_alt_action(
25
  max_depth: int,
26
  ):
27
  """
28
- This function is used to generate a new branch for a given agent.
 
 
 
 
29
  """
30
 
31
  # Generate alternative action and take a step
@@ -65,20 +73,20 @@ async def AlternativeActionsRunner(
65
  branch_only_on_new_round: bool = False,
66
  ):
67
  """
68
- This method generates a trajectory with partially completed branches,
69
- where the branching comes from taking unilateraly different actions.
70
- The resulting data is used to estimate the updated advantage alignment policy gradient terms.
71
- Let k := nb_sub_steps. Then the number of steps generated is O(Tk), where T is
72
- the maximum trajectory length.
 
 
 
73
  """
74
 
75
  tasks = []
76
  time_step = 0
77
  terminated = False
78
- root = RolloutTreeRootNode(
79
- id=markov_game.get_id(),
80
- crn_id=markov_game.get_crn_id()
81
- )
82
  previous_node = root
83
 
84
  while not terminated:
 
1
+ """
2
+ File: mllm/markov_games/alternative_actions_runner.py
3
+ Summary: Generates rollout branches by replaying trajectories with unilateral action changes.
4
+ """
5
+
6
  import asyncio
7
  import copy
8
  import json
 
21
  AgentId = str
22
 
23
 
 
24
  async def run_with_unilateral_alt_action(
25
  markov_game: MarkovGame,
26
  agent_id: AgentId,
 
29
  max_depth: int,
30
  ):
31
  """
32
+ Roll out a counterfactual branch where ``agent_id`` deviates unilaterally.
33
+
34
+ Starting from ``branch_node`` (which already contains the main trajectory),
35
+ we replay the simulation with the deviating agent's action while freezing
36
+ all other agents/actions, then continue for ``max_depth`` steps.
37
  """
38
 
39
  # Generate alternative action and take a step
 
73
  branch_only_on_new_round: bool = False,
74
  ):
75
  """
76
+ Generate a rollout tree containing the main path plus unilateral deviation branches.
77
+
78
+ For each timestep we:
79
+ 1. Cache agent actions without side effects.
80
+ 2. Advance the main trajectory.
81
+ 3. Spawn ``nb_alternative_actions`` asynchronous deviations per agent,
82
+ each replaying up to ``max_depth`` steps from the cached pre-action state.
83
+ The resulting branches feed advantage-alignment estimators.
84
  """
85
 
86
  tasks = []
87
  time_step = 0
88
  terminated = False
89
+ root = RolloutTreeRootNode(id=markov_game.get_id(), crn_id=markov_game.get_crn_id())
 
 
 
90
  previous_node = root
91
 
92
  while not terminated:
src_code_for_reproducibility/markov_games/group_timesteps.py CHANGED
@@ -1,6 +1,8 @@
1
  """
2
- This module contains the logic for grouping time steps.
 
3
  """
 
4
  import copy
5
  from typing import Callable
6
 
@@ -84,25 +86,6 @@ def group_time_steps(
84
  raise Exception(
85
  "Grouping timesteps by round is not supported for branching trajectories yet."
86
  )
87
- # Special recursive case for branches
88
- # if isinstance(current_node, RolloutTreeBranchNode):
89
- # branches = {}
90
- # for agent_id, branch_nodes in current_node.branches.items():
91
- # branch_group_nodes = []
92
- # for branch_node in branch_nodes:
93
- # branch_group_node = group_time_steps_rec(
94
- # current_node=branch_node,
95
- # group_time_step=group_time_step,
96
- # accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
97
- # branch_group_nodes.append(branch_group_node)
98
- # branches[agent_id] = branch_group_nodes
99
-
100
- # main_child_group_node = group_time_steps_rec(
101
- # current_node=current_node.main_child,
102
- # group_time_step=group_time_step,
103
- # accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
104
-
105
- # return RolloutTreeBranchNode(main_child=main_child_group_node, branches=branches)
106
 
107
  # Accumulate
108
  accumulation_step_logs.append(current_node.step_log)
 
1
  """
2
+ File: mllm/markov_games/group_timesteps.py
3
+ Summary: Provides timestep-grouping utilities for rollout trees and training.
4
  """
5
+
6
  import copy
7
  from typing import Callable
8
 
 
86
  raise Exception(
87
  "Grouping timesteps by round is not supported for branching trajectories yet."
88
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Accumulate
91
  accumulation_step_logs.append(current_node.step_log)
src_code_for_reproducibility/markov_games/linear_runner.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import asyncio
2
  import json
3
  import os.path
@@ -10,7 +15,14 @@ async def LinearRunner(
10
  markov_game: MarkovGame, output_folder: str
11
  ) -> RolloutTreeRootNode:
12
  """
13
- This method generates a trajectory without branching.
 
 
 
 
 
 
 
14
  """
15
  time_step = 0
16
  terminated = False
 
1
+ """
2
+ File: mllm/markov_games/linear_runner.py
3
+ Summary: Simulates a single unbranched Markov-game rollout and records it.
4
+ """
5
+
6
  import asyncio
7
  import json
8
  import os.path
 
15
  markov_game: MarkovGame, output_folder: str
16
  ) -> RolloutTreeRootNode:
17
  """
18
+ Generate a single main-path rollout (no branching) for the provided Markov game.
19
+
20
+ Parameters
21
+ ----------
22
+ markov_game:
23
+ Initialized ``MarkovGame`` with agents + simulation ready to step.
24
+ output_folder:
25
+ Unused placeholder in the legacy API (kept for compatibility).
26
  """
27
  time_step = 0
28
  terminated = False
src_code_for_reproducibility/markov_games/markov_game.py CHANGED
@@ -1,18 +1,8 @@
1
  """
2
- This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
3
- In a MarkovGame step,
4
- 1) each agent takes an action,
5
- 2) the state transitions with respect to these actions,
6
- 3) all relevant data of the step is appended to the historical data list
7
-
8
- In order to perform 3), the agents and the simulation are expected, at each time step,
9
- to return a log of the state transition (from their perspective).
10
- For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
11
- A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
12
- The approach we use here centralizes the data gathering aspect,
13
- making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
14
- only log information for step transitions occuring after the branching out.
15
  """
 
16
  import asyncio
17
  import copy
18
  import json
@@ -31,6 +21,8 @@ AgentId = str
31
 
32
  @dataclass
33
  class AgentAndActionSafeCopy:
 
 
34
  action: Any
35
  action_info: AgentActLog
36
  agent_after_action: type[Agent]
@@ -45,12 +37,18 @@ class MarkovGame(object):
45
  crn_id: int,
46
  ):
47
  """
48
- Args:
49
- agents:
50
- output_path:
51
- Path where the step infos are saved.
52
- simulation:
53
- Simulation object. Example: IPDSimulation
 
 
 
 
 
 
54
  """
55
  self.agents = agents
56
  self.agent_ids = self.agents.keys()
@@ -131,7 +129,7 @@ class MarkovGame(object):
131
 
132
  async def set_action_of_agent(self, agent_id: AgentId):
133
  """
134
- TOWRITE
135
  """
136
  agent = self.agents[agent_id]
137
  obs = self.simulation.get_obs_agent(agent_id)
@@ -141,7 +139,7 @@ class MarkovGame(object):
141
 
142
  async def set_actions(self):
143
  """
144
- TOWRITE
145
  """
146
  # background_tasks = set()
147
  tasks = []
@@ -152,16 +150,27 @@ class MarkovGame(object):
152
 
153
  def take_simulation_step(self):
154
  """
155
- TOWRITE
156
  """
157
  terminated, self.simulation_step_log = self.simulation.step(self.actions)
158
  return terminated
159
 
160
  def get_step_log(self) -> StepLog:
161
  """
162
- TOWRITE
163
- TODO: assert actions and simulation have taken step
164
  """
 
 
 
 
 
 
 
 
 
 
 
 
165
  step_log = StepLog(
166
  simulation_step_log=self.simulation_step_log,
167
  action_logs=self.agent_step_logs,
@@ -170,7 +179,7 @@ class MarkovGame(object):
170
 
171
  async def step(self) -> Tuple[bool, StepLog]:
172
  """
173
- TOWRITE
174
  """
175
  await self.set_actions()
176
  terminated = self.take_simulation_step()
@@ -179,7 +188,7 @@ class MarkovGame(object):
179
 
180
  def get_safe_copy(self):
181
  """
182
- TOWRITE
183
  """
184
 
185
  new_markov_game = copy.copy(self)
 
1
  """
2
+ File: mllm/markov_games/markov_game.py
3
+ Summary: Defines the MarkovGame base class plus shared simulation interfaces.
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
+
6
  import asyncio
7
  import copy
8
  import json
 
21
 
22
  @dataclass
23
  class AgentAndActionSafeCopy:
24
+ """Snapshot of an agent, its action, and metadata used for branch replay."""
25
+
26
  action: Any
27
  action_info: AgentActLog
28
  agent_after_action: type[Agent]
 
37
  crn_id: int,
38
  ):
39
  """
40
+ Initialize the Markov game wrapper.
41
+
42
+ Parameters
43
+ ----------
44
+ id:
45
+ Unique rollout identifier (logged into rollout trees).
46
+ agents:
47
+ Mapping of agent_id -> Agent instance.
48
+ simulation:
49
+ Environment implementing the ``Simulation`` interface (IPD, TAS, etc.).
50
+ crn_id:
51
+ Identifier for the common random number stream used by this rollout.
52
  """
53
  self.agents = agents
54
  self.agent_ids = self.agents.keys()
 
129
 
130
  async def set_action_of_agent(self, agent_id: AgentId):
131
  """
132
+ Query a single agent for its next action and store the result locally.
133
  """
134
  agent = self.agents[agent_id]
135
  obs = self.simulation.get_obs_agent(agent_id)
 
139
 
140
  async def set_actions(self):
141
  """
142
+ Query every agent concurrently and populate the cached actions/logs.
143
  """
144
  # background_tasks = set()
145
  tasks = []
 
150
 
151
  def take_simulation_step(self):
152
  """
153
+ Advance the simulation by one step using the cached actions.
154
  """
155
  terminated, self.simulation_step_log = self.simulation.step(self.actions)
156
  return terminated
157
 
158
  def get_step_log(self) -> StepLog:
159
  """
160
+ Package the most recent simulation step and agent logs into a StepLog.
 
161
  """
162
+ if self.simulation_step_log is None:
163
+ raise RuntimeError(
164
+ "Simulation step log is empty; call take_simulation_step() first."
165
+ )
166
+ missing_logs = [
167
+ agent_id for agent_id, log in self.agent_step_logs.items() if log is None
168
+ ]
169
+ if missing_logs:
170
+ raise RuntimeError(
171
+ f"Agent action logs missing for: {', '.join(missing_logs)}. "
172
+ "Ensure set_actions() ran before requesting the step log."
173
+ )
174
  step_log = StepLog(
175
  simulation_step_log=self.simulation_step_log,
176
  action_logs=self.agent_step_logs,
 
179
 
180
  async def step(self) -> Tuple[bool, StepLog]:
181
  """
182
+ Convenience step that collects actions, advances the simulation, and returns the log.
183
  """
184
  await self.set_actions()
185
  terminated = self.take_simulation_step()
 
188
 
189
  def get_safe_copy(self):
190
  """
191
+ Create a shallow copy of the game with deep-copied agents/simulation for branching.
192
  """
193
 
194
  new_markov_game = copy.copy(self)
src_code_for_reproducibility/markov_games/mg_utils.py CHANGED
@@ -1,9 +1,18 @@
 
 
 
 
 
1
  import asyncio
2
  import copy
3
  from collections.abc import Callable
4
  from dataclasses import dataclass
5
 
6
  from mllm.markov_games.ipd.ipd_agent import IPDAgent
 
 
 
 
7
  from mllm.markov_games.ipd.ipd_simulation import IPD
8
  from mllm.markov_games.markov_game import MarkovGame
9
  from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
@@ -12,17 +21,10 @@ from mllm.markov_games.negotiation.nego_hard_coded_policies import (
12
  HardCodedNegoGreedyPolicy,
13
  HardCodedNegoWelfareMaximizingPolicy,
14
  )
15
- from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
16
  from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
17
  from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
18
- from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
19
  from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
20
  from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
21
- from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent
22
- from mllm.markov_games.negotiation.tas_simple_simulation import (
23
- TrustAndSplitSimpleSimulation,
24
- )
25
- from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
26
  from mllm.markov_games.rollout_tree import (
27
  AgentActLog,
28
  RolloutTreeBranchNode,
@@ -37,6 +39,8 @@ AgentId = str
37
 
38
  @dataclass
39
  class AgentConfig:
 
 
40
  agent_id: str
41
  agent_name: str
42
  agent_class_name: str
@@ -46,6 +50,8 @@ class AgentConfig:
46
 
47
  @dataclass
48
  class MarkovGameConfig:
 
 
49
  id: int
50
  seed: int
51
  simulation_class_name: str
@@ -57,7 +63,9 @@ def init_markov_game_components(
57
  config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
58
  ):
59
  """
60
- TOWRITE
 
 
61
  """
62
  agents = {}
63
  agent_names = []
 
1
+ """
2
+ File: mllm/markov_games/mg_utils.py
3
+ Summary: Holds miscellaneous helpers shared across Markov-game modules.
4
+ """
5
+
6
  import asyncio
7
  import copy
8
  from collections.abc import Callable
9
  from dataclasses import dataclass
10
 
11
  from mllm.markov_games.ipd.ipd_agent import IPDAgent
12
+ from mllm.markov_games.ipd.Ipd_hard_coded_agents import (
13
+ AlwaysCooperateIPDAgent,
14
+ AlwaysDefectIPDAgent,
15
+ )
16
  from mllm.markov_games.ipd.ipd_simulation import IPD
17
  from mllm.markov_games.markov_game import MarkovGame
18
  from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
 
21
  HardCodedNegoGreedyPolicy,
22
  HardCodedNegoWelfareMaximizingPolicy,
23
  )
 
24
  from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
25
  from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
 
26
  from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
27
  from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
 
 
 
 
 
28
  from mllm.markov_games.rollout_tree import (
29
  AgentActLog,
30
  RolloutTreeBranchNode,
 
39
 
40
  @dataclass
41
  class AgentConfig:
42
+ """Configuration blob describing one agent in a Markov game spec."""
43
+
44
  agent_id: str
45
  agent_name: str
46
  agent_class_name: str
 
50
 
51
  @dataclass
52
  class MarkovGameConfig:
53
+ """Top-level config that ties together simulation settings and agent configs."""
54
+
55
  id: int
56
  seed: int
57
  simulation_class_name: str
 
63
  config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
64
  ):
65
  """
66
+ Materialize Agents and the Simulation described by ``config`` and return a MarkovGame.
67
+
68
+ `policies` is a mapping of policy_id -> callable retrieved from the hosting trainer.
69
  """
70
  agents = {}
71
  agent_names = []
src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py CHANGED
@@ -1,30 +1,45 @@
 
 
 
 
 
1
  import copy
2
  from dataclasses import dataclass
3
  from typing import Any, Dict, List, Tuple
4
 
5
  from numpy.random import default_rng
6
 
 
 
 
 
 
 
7
  from mllm.markov_games.rollout_tree import SimulationStepLog
8
- from mllm.markov_games.negotiation.nego_simulation import Split, NegotiationState, NegotiationObs, NegotiationSimulation
9
  from mllm.utils.get_coagent_id import get_coagent_id
10
 
11
-
12
  AgentId = str
13
 
14
 
15
  @dataclass
16
  class DealNoDealState(NegotiationState):
 
 
17
  item_types: List[str]
18
  values: Dict[AgentId, Dict[str, int]]
19
 
 
20
  @dataclass
21
  class DealNoDealObs(NegotiationObs):
 
 
22
  my_values: Dict[str, int]
23
  item_types: List[str]
24
  previous_values_coagent: Dict[str, int] | None
25
 
26
 
27
  def random_partition_integer(rng, total: int, parts: int) -> List[int]:
 
28
  if parts <= 0:
29
  return []
30
  if total <= 0:
@@ -37,7 +52,9 @@ def random_partition_integer(rng, total: int, parts: int) -> List[int]:
37
  prev = c
38
  return vals
39
 
 
40
  class DealNoDealSimulation(NegotiationSimulation):
 
41
 
42
  def __init__(
43
  self,
@@ -75,7 +92,9 @@ class DealNoDealSimulation(NegotiationSimulation):
75
  if ok1 and ok2:
76
  return {self.agent_ids[0]: a, self.agent_ids[1]: b}
77
 
78
- def _is_valid_allocation(self, allocation: Dict[str, int], stock: Dict[str, int]) -> bool:
 
 
79
  for t in self.item_types:
80
  v = allocation.get(t)
81
  if v is None:
@@ -85,16 +104,18 @@ class DealNoDealSimulation(NegotiationSimulation):
85
  if v < 0 or v > int(stock.get(t, 0)):
86
  return False
87
  return True
88
-
89
  def set_new_round_of_variant(self):
90
  # Keep same values, resample stock
91
  self.state.quantities = self._sample_stock()
92
 
93
- def get_info_of_variant(self, state: NegotiationState, actions: Dict[AgentId, Any]) -> Dict[str, Any]:
 
 
94
  return {
95
  "quantities": copy.deepcopy(state.quantities),
96
  "values": copy.deepcopy(state.values),
97
- 'splits': copy.deepcopy(state.splits),
98
  }
99
 
100
  def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
@@ -105,11 +126,15 @@ class DealNoDealSimulation(NegotiationSimulation):
105
  split_b = splits[self.agent_ids[1]].items_given_to_self
106
  rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
107
  for t in self.item_types:
108
- # If not complementary, return 0!
109
  if not split_a[t] + split_b[t] == self.state.quantities[t]:
110
  return {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
111
- rewards[self.agent_ids[0]] += split_a[t] * self.state.values[self.agent_ids[0]][t]
112
- rewards[self.agent_ids[1]] += split_b[t] * self.state.values[self.agent_ids[1]][t]
 
 
 
 
113
  return rewards
114
 
115
  def get_obs(self):
@@ -149,5 +174,3 @@ class DealNoDealSimulation(NegotiationSimulation):
149
  item_types=list(self.item_types),
150
  )
151
  return self.get_obs()
152
-
153
-
 
1
+ """
2
+ File: mllm/markov_games/negotiation/dond_simulation.py
3
+ Summary: Simulates Deal-or-No-Deal negotiation games and logs rollouts.
4
+ """
5
+
6
  import copy
7
  from dataclasses import dataclass
8
  from typing import Any, Dict, List, Tuple
9
 
10
  from numpy.random import default_rng
11
 
12
+ from mllm.markov_games.negotiation.nego_simulation import (
13
+ NegotiationObs,
14
+ NegotiationSimulation,
15
+ NegotiationState,
16
+ Split,
17
+ )
18
  from mllm.markov_games.rollout_tree import SimulationStepLog
 
19
  from mllm.utils.get_coagent_id import get_coagent_id
20
 
 
21
  AgentId = str
22
 
23
 
24
  @dataclass
25
  class DealNoDealState(NegotiationState):
26
+ """NegotiationState with per-agent value tables and item taxonomy."""
27
+
28
  item_types: List[str]
29
  values: Dict[AgentId, Dict[str, int]]
30
 
31
+
32
  @dataclass
33
  class DealNoDealObs(NegotiationObs):
34
+ """Observation that reveals own values and (lagged) opponent values."""
35
+
36
  my_values: Dict[str, int]
37
  item_types: List[str]
38
  previous_values_coagent: Dict[str, int] | None
39
 
40
 
41
  def random_partition_integer(rng, total: int, parts: int) -> List[int]:
42
+ """Sample non-negative integers summing to ``total`` across ``parts`` buckets."""
43
  if parts <= 0:
44
  return []
45
  if total <= 0:
 
52
  prev = c
53
  return vals
54
 
55
+
56
  class DealNoDealSimulation(NegotiationSimulation):
57
+ """NegotiationSimulation variant implementing the Rubinstein-style Deal-or-No-Deal."""
58
 
59
  def __init__(
60
  self,
 
92
  if ok1 and ok2:
93
  return {self.agent_ids[0]: a, self.agent_ids[1]: b}
94
 
95
+ def _is_valid_allocation(
96
+ self, allocation: Dict[str, int], stock: Dict[str, int]
97
+ ) -> bool:
98
  for t in self.item_types:
99
  v = allocation.get(t)
100
  if v is None:
 
104
  if v < 0 or v > int(stock.get(t, 0)):
105
  return False
106
  return True
107
+
108
  def set_new_round_of_variant(self):
109
  # Keep same values, resample stock
110
  self.state.quantities = self._sample_stock()
111
 
112
+ def get_info_of_variant(
113
+ self, state: NegotiationState, actions: Dict[AgentId, Any]
114
+ ) -> Dict[str, Any]:
115
  return {
116
  "quantities": copy.deepcopy(state.quantities),
117
  "values": copy.deepcopy(state.values),
118
+ "splits": copy.deepcopy(state.splits),
119
  }
120
 
121
  def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
 
126
  split_b = splits[self.agent_ids[1]].items_given_to_self
127
  rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
128
  for t in self.item_types:
129
+ # If not complementary, return 0!
130
  if not split_a[t] + split_b[t] == self.state.quantities[t]:
131
  return {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
132
+ rewards[self.agent_ids[0]] += (
133
+ split_a[t] * self.state.values[self.agent_ids[0]][t]
134
+ )
135
+ rewards[self.agent_ids[1]] += (
136
+ split_b[t] * self.state.values[self.agent_ids[1]][t]
137
+ )
138
  return rewards
139
 
140
  def get_obs(self):
 
174
  item_types=list(self.item_types),
175
  )
176
  return self.get_obs()
 
 
src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py CHANGED
@@ -1,7 +1,8 @@
1
  """
2
- Negotiation simulation environment
3
- other agent is set at the start of every round. Even though current agent changes over message turns in a round.
4
  """
 
5
  import copy
6
  from abc import abstractmethod
7
  from dataclasses import dataclass
@@ -18,16 +19,22 @@ AgentId = str
18
 
19
  @dataclass
20
  class Split:
 
 
21
  items_given_to_self: Dict[str, int]
22
 
23
 
24
  @dataclass
25
  class Message:
 
 
26
  message: str
27
 
28
 
29
  @dataclass # gets extended by variants
30
  class NegotiationState:
 
 
31
  round_nb: int
32
  last_message: str
33
  current_agent: AgentId
@@ -44,6 +51,8 @@ class NegotiationState:
44
 
45
  @dataclass # gets extended by variants
46
  class NegotiationObs:
 
 
47
  round_nb: int
48
  last_message: str
49
  quota_messages_per_agent_per_round: int
@@ -134,12 +143,14 @@ class NegotiationSimulation(Simulation):
134
 
135
  @abstractmethod
136
  def set_new_round_of_variant(self):
 
137
  pass
138
 
139
  @abstractmethod
140
  def get_info_of_variant(
141
  self, state: NegotiationState, actions: Dict[AgentId, Any]
142
  ) -> Dict[str, Any]:
 
143
  pass
144
 
145
  def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
@@ -190,7 +201,7 @@ class NegotiationSimulation(Simulation):
190
  is_last_timestep_in_round = True
191
  done = self.state.round_nb >= self.nb_of_rounds
192
 
193
- # Message phase
194
  elif isinstance(action, Message):
195
  self.state.last_message = action.message
196
  self.state.nb_messages_sent[current_agent] += 1
 
1
  """
2
+ File: mllm/markov_games/negotiation/nego_simulation.py
3
+ Summary: Simulation harness for general negotiation environments.
4
  """
5
+
6
  import copy
7
  from abc import abstractmethod
8
  from dataclasses import dataclass
 
19
 
20
  @dataclass
21
  class Split:
22
+ """Structured proposal describing how many units of each item an agent keeps."""
23
+
24
  items_given_to_self: Dict[str, int]
25
 
26
 
27
  @dataclass
28
  class Message:
29
+ """Single chat utterance exchanged during the negotiation phase."""
30
+
31
  message: str
32
 
33
 
34
  @dataclass # gets extended by variants
35
  class NegotiationState:
36
+ """Full simulator state snapshot shared by all negotiation variants."""
37
+
38
  round_nb: int
39
  last_message: str
40
  current_agent: AgentId
 
51
 
52
  @dataclass # gets extended by variants
53
  class NegotiationObs:
54
+ """Observation presented to agents each turn (base fields; variants extend)."""
55
+
56
  round_nb: int
57
  last_message: str
58
  quota_messages_per_agent_per_round: int
 
143
 
144
  @abstractmethod
145
  def set_new_round_of_variant(self):
146
+ """Variant hook: sample new private values / stock before each round."""
147
  pass
148
 
149
  @abstractmethod
150
  def get_info_of_variant(
151
  self, state: NegotiationState, actions: Dict[AgentId, Any]
152
  ) -> Dict[str, Any]:
153
+ """Variant hook: populate SimulationStepLog.info with custom diagnostics."""
154
  pass
155
 
156
  def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
 
201
  is_last_timestep_in_round = True
202
  done = self.state.round_nb >= self.nb_of_rounds
203
 
204
+ # Message phase: roll the conversation forward a single turn.
205
  elif isinstance(action, Message):
206
  self.state.last_message = action.message
207
  self.state.nb_messages_sent[current_agent] += 1
src_code_for_reproducibility/markov_games/negotiation/tas_agent.py CHANGED
@@ -1,9 +1,16 @@
 
 
 
 
 
1
  from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
2
  from mllm.markov_games.negotiation.nego_simulation import Split
3
  from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
4
 
5
 
6
  class TrustAndSplitAgent(NegotiationAgent):
 
 
7
  def __init__(self, num_message_chars, *args, **kwargs):
8
  self.num_message_chars = num_message_chars
9
  super().__init__(*args, **kwargs)
@@ -58,12 +65,14 @@ class TrustAndSplitAgent(NegotiationAgent):
58
  self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
59
 
60
  def get_message_regex(self, observation: TrustAndSplitObs) -> str:
 
61
  return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
62
 
63
  # def get_message_regex(self, observation: TrustAndSplitObs) -> str:
64
  # return rf"(?s).{{0,{self.num_message_chars}}}"
65
 
66
  def get_split_regex(self, observation: TrustAndSplitObs) -> str:
 
67
  items = list(observation.quantities.keys())
68
  # Accept both singular and plural forms
69
  item_pattern = "|".join(
@@ -75,6 +84,7 @@ class TrustAndSplitAgent(NegotiationAgent):
75
  def get_split_action(
76
  self, policy_output: str, observation: TrustAndSplitObs
77
  ) -> Split:
 
78
  items = list(observation.quantities.keys())
79
  import re as _re
80
 
 
1
+ """
2
+ File: mllm/markov_games/negotiation/tas_agent.py
3
+ Summary: Agent implementation for Take-and-Split negotiations.
4
+ """
5
+
6
  from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
7
  from mllm.markov_games.negotiation.nego_simulation import Split
8
  from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
9
 
10
 
11
  class TrustAndSplitAgent(NegotiationAgent):
12
+ """Prompt/template wrapper for the classic multi-item Take-and-Split benchmark."""
13
+
14
  def __init__(self, num_message_chars, *args, **kwargs):
15
  self.num_message_chars = num_message_chars
16
  super().__init__(*args, **kwargs)
 
65
  self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
66
 
67
  def get_message_regex(self, observation: TrustAndSplitObs) -> str:
68
+ """Constrain chat to bounded XML tags for stable parsing."""
69
  return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
70
 
71
  # def get_message_regex(self, observation: TrustAndSplitObs) -> str:
72
  # return rf"(?s).{{0,{self.num_message_chars}}}"
73
 
74
  def get_split_regex(self, observation: TrustAndSplitObs) -> str:
75
+ """Allow natural-language item names while still returning machine-parsable XML."""
76
  items = list(observation.quantities.keys())
77
  # Accept both singular and plural forms
78
  item_pattern = "|".join(
 
84
  def get_split_action(
85
  self, policy_output: str, observation: TrustAndSplitObs
86
  ) -> Split:
87
+ """Convert human-readable allocation text back into canonical item IDs."""
88
  items = list(observation.quantities.keys())
89
  import re as _re
90
 
src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import copy
2
  from collections.abc import Callable
3
  from dataclasses import dataclass
@@ -15,6 +20,8 @@ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
15
 
16
 
17
  class TrustAndSplitRPSAgent(NegotiationAgent):
 
 
18
  def __init__(
19
  self,
20
  num_message_chars: int,
@@ -88,6 +95,7 @@ class TrustAndSplitRPSAgent(NegotiationAgent):
88
  self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
89
 
90
  def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
 
91
  if self.message_start_end_format:
92
  return (
93
  rf"<<message_start>>[\s\S]{{0,{self.num_message_chars}}}<<message_end>>"
@@ -96,6 +104,7 @@ class TrustAndSplitRPSAgent(NegotiationAgent):
96
  return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
97
 
98
  def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
 
99
  if self.proposal_start_end_format:
100
  return r"<<proposal_start>> ?(10|[0-9]) ?<<proposal_end>>"
101
  else:
@@ -104,6 +113,7 @@ class TrustAndSplitRPSAgent(NegotiationAgent):
104
  def get_split_action(
105
  self, policy_output: str, observation: TrustAndSplitRPSObs
106
  ) -> Split:
 
107
  import re as _re
108
 
109
  if self.proposal_start_end_format:
 
1
+ """
2
+ File: mllm/markov_games/negotiation/tas_rps_agent.py
3
+ Summary: Agent logic for TAS Rock-Paper-Scissors blended game.
4
+ """
5
+
6
  import copy
7
  from collections.abc import Callable
8
  from dataclasses import dataclass
 
20
 
21
 
22
  class TrustAndSplitRPSAgent(NegotiationAgent):
23
+ """NegotiationAgent that reasons about hidden hands before submitting TAS splits."""
24
+
25
  def __init__(
26
  self,
27
  num_message_chars: int,
 
95
  self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
96
 
97
  def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
98
+ """Switch between <message>...</message> and <<message_start>> formats on demand."""
99
  if self.message_start_end_format:
100
  return (
101
  rf"<<message_start>>[\s\S]{{0,{self.num_message_chars}}}<<message_end>>"
 
104
  return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
105
 
106
  def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
107
+ """Force single-number proposals inside whichever tag style the config selected."""
108
  if self.proposal_start_end_format:
109
  return r"<<proposal_start>> ?(10|[0-9]) ?<<proposal_end>>"
110
  else:
 
113
  def get_split_action(
114
  self, policy_output: str, observation: TrustAndSplitRPSObs
115
  ) -> Split:
116
+ """Parse the proposal tag (or raw integer fallback) into a Split."""
117
  import re as _re
118
 
119
  if self.proposal_start_end_format:
src_code_for_reproducibility/markov_games/rollout_tree.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- TODO: add parent to nodes so that some verification can be done. For instance, to ensure that node reward keys match the parent node.
 
3
  """
4
 
5
  from __future__ import annotations
@@ -18,11 +19,15 @@ AgentId = str
18
 
19
 
20
  class SimulationStepLog(BaseModel):
 
 
21
  rewards: dict[AgentId, float]
22
  info: Any = None
23
 
24
 
25
  class AgentActLog(BaseModel):
 
 
26
  chat_turns: list[ChatTurn] | None
27
  info: Any = None
28
 
@@ -55,6 +60,8 @@ class StepLog(BaseModel):
55
 
56
 
57
  class RolloutTreeNode(BaseModel):
 
 
58
  step_log: StepLog
59
  time_step: int
60
  child: RolloutTreeNode | RolloutTreeBranchNode | None = None
@@ -70,6 +77,8 @@ class RolloutTreeBranchNode(BaseModel):
70
 
71
 
72
  class RolloutTreeRootNode(BaseModel):
 
 
73
  id: int
74
  crn_id: int # ID of the rng used to generate this rollout tree
75
  child: RolloutTreeNode | RolloutTreeBranchNode | None = None
 
1
  """
2
+ File: mllm/markov_games/rollout_tree.py
3
+ Summary: Defines rollout tree data structures and serialization helpers.
4
  """
5
 
6
  from __future__ import annotations
 
19
 
20
 
21
  class SimulationStepLog(BaseModel):
22
+ """Minimal snapshot of environment-side rewards and auxiliary info."""
23
+
24
  rewards: dict[AgentId, float]
25
  info: Any = None
26
 
27
 
28
  class AgentActLog(BaseModel):
29
+ """LLM-side provenance for an action (chat turns + metadata)."""
30
+
31
  chat_turns: list[ChatTurn] | None
32
  info: Any = None
33
 
 
60
 
61
 
62
  class RolloutTreeNode(BaseModel):
63
+ """Single timestep of the main trajectory (or a branch) plus linkage."""
64
+
65
  step_log: StepLog
66
  time_step: int
67
  child: RolloutTreeNode | RolloutTreeBranchNode | None = None
 
77
 
78
 
79
  class RolloutTreeRootNode(BaseModel):
80
+ """Entry point for serialized rollouts (main path plus optional branches)."""
81
+
82
  id: int
83
  crn_id: int # ID of the rng used to generate this rollout tree
84
  child: RolloutTreeNode | RolloutTreeBranchNode | None = None
src_code_for_reproducibility/markov_games/run_markov_games.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import asyncio
2
  from collections.abc import Callable
3
  from dataclasses import dataclass
@@ -14,6 +19,12 @@ async def run_markov_games(
14
  output_folder: str,
15
  markov_games: list[MarkovGame],
16
  ) -> list[RolloutTreeRootNode]:
 
 
 
 
 
 
17
  tasks = []
18
  for mg in markov_games:
19
  tasks.append(
 
1
+ """
2
+ File: mllm/markov_games/run_markov_games.py
3
+ Summary: CLI entry point for running configured Markov-game experiments.
4
+ """
5
+
6
  import asyncio
7
  from collections.abc import Callable
8
  from dataclasses import dataclass
 
19
  output_folder: str,
20
  markov_games: list[MarkovGame],
21
  ) -> list[RolloutTreeRootNode]:
22
+ """
23
+ Kick off multiple Markov game rollouts concurrently and return their trees.
24
+
25
+ Parameters mirror the Hydra configs (runner callable + kwargs) so callers can
26
+ choose ``LinearRunner``, ``AlternativeActionsRunner`` or future variants.
27
+ """
28
  tasks = []
29
  for mg in markov_games:
30
  tasks.append(
src_code_for_reproducibility/markov_games/simulation.py CHANGED
@@ -1,8 +1,6 @@
1
  """
2
- A Simulation is the environment of a Markov Game.
3
- The Simulation is not responsible for properly checking / formatting the responses of LLM's.
4
- This is the job of the `Agent` class.
5
- Simulations expect clean actions, and are defined similarly to `gymnasium` environments, except that they are adapted for the Multi-agent setting.
6
  """
7
 
8
  from abc import ABC, abstractmethod
@@ -22,59 +20,68 @@ class Simulation(ABC):
22
  @abstractmethod
23
  def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
24
  """
25
- Returns terminated, info
 
 
 
 
 
 
 
26
  """
27
  raise NotImplementedError
28
 
29
  def get_obs(self):
30
- """Returns all agent observations in dict
31
-
32
- Returns:
33
- observations
34
- """
35
  raise NotImplementedError
36
 
37
  def get_obs_agent(self, agent_id):
38
- """Returns observation for agent_id"""
39
  raise NotImplementedError
40
 
41
  def get_obs_size(self):
42
- """Returns the shape of the observation"""
43
  raise NotImplementedError
44
 
45
  def get_state(self):
 
46
  raise NotImplementedError
47
 
48
  def get_state_size(self):
49
- """Returns the shape of the state"""
50
  raise NotImplementedError
51
 
52
  def get_avail_actions(self):
 
53
  raise NotImplementedError
54
 
55
  def get_avail_agent_actions(self, agent_id):
56
- """Returns the available actions for agent_id"""
57
  raise NotImplementedError
58
 
59
  def get_total_actions(self):
60
- """Returns the total number of actions an agent could ever take"""
61
- # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
 
 
62
  raise NotImplementedError
63
 
64
  def get_safe_copy(self):
65
  """
66
- Return copy of the agent object that is decorrelated from the original object.
67
  """
68
  raise NotImplementedError
69
 
70
  def reset(self):
71
- """Returns initial observations and states"""
72
  raise NotImplementedError
73
 
74
  def render(self):
 
75
  raise NotImplementedError
76
 
77
  def close(self):
 
78
  raise NotImplementedError
79
 
80
  # def seed(self):
 
1
  """
2
+ File: mllm/markov_games/simulation.py
3
+ Summary: Core simulation loop utilities and step logging for Markov games.
 
 
4
  """
5
 
6
  from abc import ABC, abstractmethod
 
20
  @abstractmethod
21
  def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
22
  """
23
+ Advance the environment by one logical tick using ``actions``.
24
+
25
+ Returns
26
+ -------
27
+ terminated: bool
28
+ Whether the episode has finished.
29
+ SimulationStepLog
30
+ Reward/info bundle describing this transition.
31
  """
32
  raise NotImplementedError
33
 
34
  def get_obs(self):
35
+ """Return a dict mapping agent_id -> observation for *all* agents."""
 
 
 
 
36
  raise NotImplementedError
37
 
38
  def get_obs_agent(self, agent_id):
39
+ """Return the observation for a single agent."""
40
  raise NotImplementedError
41
 
42
  def get_obs_size(self):
43
+ """Describe the observation tensor shape (useful for critic heads)."""
44
  raise NotImplementedError
45
 
46
  def get_state(self):
47
+ """Return the privileged simulator state if available."""
48
  raise NotImplementedError
49
 
50
  def get_state_size(self):
51
+ """Describe the state tensor shape."""
52
  raise NotImplementedError
53
 
54
  def get_avail_actions(self):
55
+ """Return the global action mask/tensor if the space is discrete."""
56
  raise NotImplementedError
57
 
58
  def get_avail_agent_actions(self, agent_id):
59
+ """Return the available action mask for a given agent."""
60
  raise NotImplementedError
61
 
62
  def get_total_actions(self):
63
+ """Returns the total number of actions an agent could ever take.
64
+
65
+ Implementations currently assume a discrete, one-dimensional action space per agent.
66
+ """
67
  raise NotImplementedError
68
 
69
  def get_safe_copy(self):
70
  """
71
+ Return copy of the simulator that shares no mutable state with the original.
72
  """
73
  raise NotImplementedError
74
 
75
  def reset(self):
76
+ """Reset to the initial state and return the starting observations."""
77
  raise NotImplementedError
78
 
79
  def render(self):
80
+ """Optional human-facing visualization."""
81
  raise NotImplementedError
82
 
83
  def close(self):
84
+ """Release any owned resources (files, processes, etc.)."""
85
  raise NotImplementedError
86
 
87
  # def seed(self):
src_code_for_reproducibility/markov_games/statistics_runner.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import gc
@@ -36,17 +41,20 @@ def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]:
36
  def iterate_main_simulation_logs(
37
  root: RolloutTreeRootNode,
38
  ) -> Iterator[SimulationStepLog]:
 
39
  for node in _iterate_main_nodes(root):
40
  yield node.step_log.simulation_step_log
41
 
42
 
43
  def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
 
44
  for p in iteration_folder.rglob("*.rt.pkl"):
45
  if p.is_file():
46
  yield p
47
 
48
 
49
  def load_root(path: Path) -> RolloutTreeRootNode:
 
50
  with open(path, "rb") as f:
51
  data = pickle.load(f)
52
  return RolloutTreeRootNode.model_validate(data)
@@ -54,6 +62,8 @@ def load_root(path: Path) -> RolloutTreeRootNode:
54
 
55
  @dataclass
56
  class StatRecord:
 
 
57
  mgid: int
58
  crn_id: Optional[int]
59
  iteration: str
 
1
+ """
2
+ File: mllm/markov_games/statistics_runner.py
3
+ Summary: Executes multiple rollouts to compute experiment statistics.
4
+ """
5
+
6
  from __future__ import annotations
7
 
8
  import gc
 
41
  def iterate_main_simulation_logs(
42
  root: RolloutTreeRootNode,
43
  ) -> Iterator[SimulationStepLog]:
44
+ """Yield ``SimulationStepLog`` objects along the main (non-branch) path."""
45
  for node in _iterate_main_nodes(root):
46
  yield node.step_log.simulation_step_log
47
 
48
 
49
  def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
50
+ """Iterate over every ``*.rt.pkl`` file under an iteration directory."""
51
  for p in iteration_folder.rglob("*.rt.pkl"):
52
  if p.is_file():
53
  yield p
54
 
55
 
56
  def load_root(path: Path) -> RolloutTreeRootNode:
57
+ """Load and validate a rollout tree from disk."""
58
  with open(path, "rb") as f:
59
  data = pickle.load(f)
60
  return RolloutTreeRootNode.model_validate(data)
 
62
 
63
  @dataclass
64
  class StatRecord:
65
+ """Convenience container for serialized stat rows."""
66
+
67
  mgid: int
68
  crn_id: Optional[int]
69
  iteration: str
src_code_for_reproducibility/models/__init__.py CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ File: mllm/models/__init__.py
3
+ Summary: Exports model-layer utilities from the models package.
4
+ """
src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc differ
 
src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc differ
 
src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc differ
 
src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc differ
 
src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc differ
 
src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc differ
 
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc differ
 
src_code_for_reproducibility/models/adapter_training_wrapper.py CHANGED
@@ -1,11 +1,14 @@
1
- import torch
2
- import torch.nn as nn
 
 
 
3
  import logging
4
  from typing import Union
5
- from peft import (
6
- LoraConfig,
7
- get_peft_model,
8
- )
9
 
10
  logger = logging.getLogger(__name__)
11
 
@@ -18,13 +21,14 @@ class AdapterWrapper(nn.Module):
18
  • exposes only the parameters that should be trained for that adapter
19
  (plus whatever extra modules you name).
20
  """
 
21
  def __init__(
22
  self,
23
  shared_llm: nn.Module,
24
  adapter_id: str,
25
  lora_config: dict,
26
  path: Union[str, None] = None,
27
- ):
28
  super().__init__()
29
  self.shared_llm = shared_llm
30
  self.adapter_id = adapter_id
@@ -47,7 +51,9 @@ class AdapterWrapper(nn.Module):
47
  adapter_name=adapter_id,
48
  )
49
  loaded_from = path
50
- except Exception as exc: # noqa: BLE001 - want to log any load failure context
 
 
51
  logger.warning(
52
  f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
53
  )
 
1
+ """
2
+ File: mllm/models/adapter_training_wrapper.py
3
+ Summary: Wraps a shared LLM with adapter-specific PEFT handling for training.
4
+ """
5
+
6
  import logging
7
  from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from peft import LoraConfig, get_peft_model
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
21
  • exposes only the parameters that should be trained for that adapter
22
  (plus whatever extra modules you name).
23
  """
24
+
25
  def __init__(
26
  self,
27
  shared_llm: nn.Module,
28
  adapter_id: str,
29
  lora_config: dict,
30
  path: Union[str, None] = None,
31
+ ):
32
  super().__init__()
33
  self.shared_llm = shared_llm
34
  self.adapter_id = adapter_id
 
51
  adapter_name=adapter_id,
52
  )
53
  loaded_from = path
54
+ except (
55
+ Exception
56
+ ) as exc: # noqa: BLE001 - want to log any load failure context
57
  logger.warning(
58
  f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
59
  )
src_code_for_reproducibility/models/human_policy.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import asyncio
2
  import os
3
  import re
 
1
+ """
2
+ File: mllm/models/human_policy.py
3
+ Summary: Implements an interactive human-in-the-loop policy for experiments.
4
+ """
5
+
6
  import asyncio
7
  import os
8
  import re
src_code_for_reproducibility/models/inference_backend.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
  from dataclasses import dataclass
3
  from typing import Any, Optional
 
1
+ """
2
+ File: mllm/models/inference_backend.py
3
+ Summary: Declares the inference backend interface and shared dataclasses.
4
+ """
5
+
6
  from abc import ABC, abstractmethod
7
  from dataclasses import dataclass
8
  from typing import Any, Optional
src_code_for_reproducibility/models/inference_backend_dummy.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import asyncio
2
  from typing import Optional
3
 
 
1
+ """
2
+ File: mllm/models/inference_backend_dummy.py
3
+ Summary: Stub inference backend that returns synthetic completions for tests.
4
+ """
5
+
6
  import asyncio
7
  from typing import Optional
8
 
src_code_for_reproducibility/models/inference_backend_vllm.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import asyncio
2
  import re
3
  from typing import Optional
@@ -23,19 +28,12 @@ class VLLMAsyncBackend(LLMInferenceBackend):
23
  sampling_params: dict = {},
24
  ):
25
  self.model_name = model_name
26
- # self.adapter_paths = adapter_paths or {}
27
- # self.current_adapter = None
28
- # self.vllm_adapter_ids = {
29
- # adapter_id: generate_short_id() for adapter_id in adapter_paths.keys()
30
- # }
31
  self.vllm_adapter_ids = {}
32
  ea = dict(model=model_name, **engine_init_kwargs)
33
- # ea["enable_lora"] = True
34
- # ea["max_loras"] = len(self.vllm_adapter_ids)
35
- # ea["enable_sleep_mode"] = True
36
  self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))
37
 
38
  self.sampling_params = sampling_params
 
39
 
40
  def prepare_adapter(
41
  self,
@@ -43,7 +41,6 @@ class VLLMAsyncBackend(LLMInferenceBackend):
43
  adapter_path: Optional[str],
44
  weights_got_updated: bool,
45
  ) -> None:
46
- # self.current_adapter = adapter_id
47
  if weights_got_updated:
48
  self.vllm_adapter_ids[adapter_id] = generate_short_id()
49
  self.current_lora_request = LoRARequest(
@@ -96,9 +93,6 @@ class VLLMAsyncBackend(LLMInferenceBackend):
96
  ]
97
  log_probs = torch.tensor(log_probs)
98
  out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
99
- # for out_token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs):
100
- # if logprob_dict[out_token_id].logprob < -1:
101
- # print(f"High negative logprob {logprob_dict[out_token_id].logprob} for {logprob_dict}")
102
  content = raw_text
103
  reasoning_content = None
104
 
 
1
+ """
2
+ File: mllm/models/inference_backend_vllm.py
3
+ Summary: Connects to in-process vLLM instances for batched generation.
4
+ """
5
+
6
  import asyncio
7
  import re
8
  from typing import Optional
 
28
  sampling_params: dict = {},
29
  ):
30
  self.model_name = model_name
 
 
 
 
 
31
  self.vllm_adapter_ids = {}
32
  ea = dict(model=model_name, **engine_init_kwargs)
 
 
 
33
  self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))
34
 
35
  self.sampling_params = sampling_params
36
+ self.tokenizer = tokenizer
37
 
38
  def prepare_adapter(
39
  self,
 
41
  adapter_path: Optional[str],
42
  weights_got_updated: bool,
43
  ) -> None:
 
44
  if weights_got_updated:
45
  self.vllm_adapter_ids[adapter_id] = generate_short_id()
46
  self.current_lora_request = LoRARequest(
 
93
  ]
94
  log_probs = torch.tensor(log_probs)
95
  out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
 
 
 
96
  content = raw_text
97
  reasoning_content = None
98
 
src_code_for_reproducibility/models/large_language_model_api.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import asyncio
@@ -13,7 +18,7 @@ from openai import AsyncOpenAI, OpenAIError
13
  from mllm.markov_games.rollout_tree import ChatTurn
14
  from mllm.models.inference_backend import LLMInferenceOutput
15
 
16
- # TODO: Get this automatically from OpenAI
17
  reasoning_models = [
18
  "gpt-5-nano",
19
  "gpt-5-mini",
@@ -119,9 +124,7 @@ class LargeLanguageModelOpenAI:
119
  agent_id: str,
120
  regex: Optional[str] = None,
121
  ) -> LLMInferenceOutput:
122
- # Remove any non-role/content keys from the prompt else openai will error
123
-
124
- # TODO:
125
  prompt = [{"role": p.role, "content": p.content} for p in state]
126
 
127
  # if self.sleep_between_requests:
 
1
+ """
2
+ File: mllm/models/large_language_model_api.py
3
+ Summary: Implements API-based large-language-model inference adapters.
4
+ """
5
+
6
  from __future__ import annotations
7
 
8
  import asyncio
 
18
  from mllm.markov_games.rollout_tree import ChatTurn
19
  from mllm.models.inference_backend import LLMInferenceOutput
20
 
21
+ # Static list copied from the public OpenAI docs until a discovery endpoint is exposed.
22
  reasoning_models = [
23
  "gpt-5-nano",
24
  "gpt-5-mini",
 
124
  agent_id: str,
125
  regex: Optional[str] = None,
126
  ) -> LLMInferenceOutput:
127
+ # Remove any non-role/content keys from the prompt else openai will error.
 
 
128
  prompt = [{"role": p.role, "content": p.content} for p in state]
129
 
130
  # if self.sleep_between_requests:
src_code_for_reproducibility/models/large_language_model_local.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- TODO: Figure out how to tweak SGlang not to go OOM when batch size is 32. See https://github.com/sgl-project/sglang/issues/6309.
 
3
  """
4
 
5
  import logging
@@ -16,23 +17,14 @@ import httpx
16
  import requests
17
  import torch
18
  import torch.nn as nn
19
-
20
- # from sglang.utils import (
21
- # launch_server_cmd,
22
- # print_highlight,
23
- # terminate_process,
24
- # wait_for_server,
25
- # )
26
  from torch.optim import SGD, Adam, AdamW, RMSprop
27
  from transformers import AutoModelForCausalLM, AutoTokenizer
28
- from trl import AutoModelForCausalLMWithValueHead
29
 
30
  from mllm.chat_utils.apply_template import chat_turns_to_token_ids
31
  from mllm.markov_games.rollout_tree import ChatTurn
32
  from mllm.models.adapter_training_wrapper import AdapterWrapper
33
  from mllm.models.inference_backend import LLMInferenceOutput
34
  from mllm.models.inference_backend_dummy import DummyInferenceBackend
35
- from mllm.models.inference_backend_sglang import SGLangOfflineBackend
36
  from mllm.models.inference_backend_vllm import VLLMAsyncBackend
37
 
38
  logger = logging.getLogger(__name__)
@@ -44,7 +36,7 @@ PolicyID = str
44
 
45
  class LeanLocalLLM:
46
  """
47
- TOWRITE
48
  """
49
 
50
  def __init__(
@@ -55,7 +47,7 @@ class LeanLocalLLM:
55
  hf_kwargs: dict = {},
56
  adapter_configs: dict = {},
57
  output_directory: str = "./models/",
58
- inference_backend: Literal["vllm", "sglang", "dummy"] = "vllm",
59
  inference_backend_sampling_params: dict = {},
60
  inference_backend_init_kwargs: dict = {},
61
  initial_adapter_paths: dict[str, str] | None = None,
@@ -180,15 +172,7 @@ class LeanLocalLLM:
180
  # Init inference backend
181
  # ---------------------------------------------------------
182
 
183
- if inference_backend == "sglang":
184
- self.inference_backend = SGLangOfflineBackend(
185
- model_name=self.model_name,
186
- save_path=self.save_path,
187
- adapter_paths=self.adapter_paths,
188
- tokenizer=self.tokenizer,
189
- kwargs=inference_backend_init_kwargs,
190
- )
191
- elif inference_backend == "vllm":
192
  self.inference_backend = VLLMAsyncBackend(
193
  model_name=self.model_name,
194
  # adapter_paths=self.adapter_paths,
@@ -206,7 +190,7 @@ class LeanLocalLLM:
206
 
207
  def get_inference_policies(self) -> dict[PolicyID, Callable]:
208
  """
209
- TOWRITE
210
  """
211
  policies = {}
212
  for adapter_id in self.adapter_ids:
@@ -242,8 +226,8 @@ class LeanLocalLLM:
242
  """
243
  Returns wrappers over the adapters, which allow them to be
244
  interfaced like regular PyTorch models.
245
- # TODO: create the adapter wrappers here
246
- See adapter_wrapper.py
247
  """
248
  trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
249
  return trainable_objects
@@ -297,13 +281,11 @@ class LeanLocalLLM:
297
  tokenizer=self.tokenizer,
298
  enable_thinking=self.enable_thinking,
299
  )
300
- # print(f"context is {self.tokenizer.decode(context_token_ids)}")
301
  policy_output = await self.inference_backend.generate(
302
  input_token_ids=context_token_ids.tolist(),
303
  extract_thinking=(self.max_thinking_characters > 0),
304
  regex=current_regex,
305
  )
306
- # print(f"generated: {self.tokenizer.decode(policy_output.out_token_ids)}")
307
  if (
308
  pattern is None
309
  or (pattern.fullmatch(policy_output.content))
@@ -347,11 +329,6 @@ class LeanLocalLLM:
347
  for adapter_id in self.past_agent_adapter_ids:
348
  self.weights_got_updated[adapter_id] = True
349
 
350
- # import random
351
- # self.save_path = self.save_path + str(random.randint(1,500))
352
- # print(f"Save path: {self.save_path}")
353
- # self.adapter_paths = {adapter_id:os.path.join(self.save_path, adapter_id) for adapter_id in self.adapter_ids}
354
-
355
  adapter_id = self.adapter_ids[0]
356
  self.hf_adapters[adapter_id].save_pretrained(self.save_path)
357
 
 
1
  """
2
+ File: mllm/models/large_language_model_local.py
3
+ Summary: Provides a local large language model wrapper over inference backends.
4
  """
5
 
6
  import logging
 
17
  import requests
18
  import torch
19
  import torch.nn as nn
 
 
 
 
 
 
 
20
  from torch.optim import SGD, Adam, AdamW, RMSprop
21
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
22
 
23
  from mllm.chat_utils.apply_template import chat_turns_to_token_ids
24
  from mllm.markov_games.rollout_tree import ChatTurn
25
  from mllm.models.adapter_training_wrapper import AdapterWrapper
26
  from mllm.models.inference_backend import LLMInferenceOutput
27
  from mllm.models.inference_backend_dummy import DummyInferenceBackend
 
28
  from mllm.models.inference_backend_vllm import VLLMAsyncBackend
29
 
30
  logger = logging.getLogger(__name__)
 
36
 
37
  class LeanLocalLLM:
38
  """
39
+ Wrapper that manages local HuggingFace models, adapters, and inference backends.
40
  """
41
 
42
  def __init__(
 
47
  hf_kwargs: dict = {},
48
  adapter_configs: dict = {},
49
  output_directory: str = "./models/",
50
+ inference_backend: Literal["vllm", "dummy"] = "vllm",
51
  inference_backend_sampling_params: dict = {},
52
  inference_backend_init_kwargs: dict = {},
53
  initial_adapter_paths: dict[str, str] | None = None,
 
172
  # Init inference backend
173
  # ---------------------------------------------------------
174
 
175
+ if inference_backend == "vllm":
 
 
 
 
 
 
 
 
176
  self.inference_backend = VLLMAsyncBackend(
177
  model_name=self.model_name,
178
  # adapter_paths=self.adapter_paths,
 
190
 
191
  def get_inference_policies(self) -> dict[PolicyID, Callable]:
192
  """
193
+ Build async policy callables keyed by adapter id for inference-only usage.
194
  """
195
  policies = {}
196
  for adapter_id in self.adapter_ids:
 
226
  """
227
  Returns wrappers over the adapters, which allow them to be
228
  interfaced like regular PyTorch models.
229
  + AdapterWrapper lives in adapter_training_wrapper.py; the huggingface modules already wrap
230
+ parameters here, so we surface them directly until an extra shim is required.
231
  """
232
  trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
233
  return trainable_objects
 
281
  tokenizer=self.tokenizer,
282
  enable_thinking=self.enable_thinking,
283
  )
 
284
  policy_output = await self.inference_backend.generate(
285
  input_token_ids=context_token_ids.tolist(),
286
  extract_thinking=(self.max_thinking_characters > 0),
287
  regex=current_regex,
288
  )
 
289
  if (
290
  pattern is None
291
  or (pattern.fullmatch(policy_output.content))
 
329
  for adapter_id in self.past_agent_adapter_ids:
330
  self.weights_got_updated[adapter_id] = True
331
 
 
 
 
 
 
332
  adapter_id = self.adapter_ids[0]
333
  self.hf_adapters[adapter_id].save_pretrained(self.save_path)
334
 
src_code_for_reproducibility/models/scalar_critic.py CHANGED
@@ -1,6 +1,13 @@
1
- import torch, torch.nn as nn, torch.optim as optim
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
3
  from peft import LoraConfig, get_peft_model
 
4
 
5
  from mllm.models.adapter_training_wrapper import AdapterWrapper
6
 
@@ -11,18 +18,16 @@ class ScalarCritic(nn.Module):
11
  V_φ(s) = wᵀ h_last + b
12
  Only LoRA adapters (inside critic_adapter) and the value head are trainable.
13
  """
 
14
  def __init__(self, critic_adapter: AdapterWrapper):
15
  super().__init__()
16
  self.critic_adapter = critic_adapter
17
  hidden_size = self.critic_adapter.shared_llm.config.hidden_size
18
  self.value_head = nn.Linear(hidden_size, 1).to(
19
- dtype=critic_adapter.dtype,
20
- device=critic_adapter.device)
21
 
22
- def forward(self,
23
- input_ids,
24
- attention_mask=None,
25
- **kwargs):
26
  # AdapterWrapper activates its own adapter internally
27
  outputs = self.critic_adapter(
28
  input_ids=input_ids,
@@ -30,7 +35,7 @@ class ScalarCritic(nn.Module):
30
  output_hidden_states=True,
31
  **kwargs,
32
  )
33
- h_last = outputs.hidden_states[-1] # (B, S, H)
34
  values = self.value_head(h_last).squeeze(-1) # (B, S)
35
  return values
36
 
 
1
+ """
2
+ File: mllm/models/scalar_critic.py
3
+ Summary: Defines a scalar critic network and helper utilities.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
  from peft import LoraConfig, get_peft_model
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
  from mllm.models.adapter_training_wrapper import AdapterWrapper
13
 
 
18
  V_φ(s) = wᵀ h_last + b
19
  Only LoRA adapters (inside critic_adapter) and the value head are trainable.
20
  """
21
+
22
  def __init__(self, critic_adapter: AdapterWrapper):
23
  super().__init__()
24
  self.critic_adapter = critic_adapter
25
  hidden_size = self.critic_adapter.shared_llm.config.hidden_size
26
  self.value_head = nn.Linear(hidden_size, 1).to(
27
+ dtype=critic_adapter.dtype, device=critic_adapter.device
28
+ )
29
 
30
+ def forward(self, input_ids, attention_mask=None, **kwargs):
 
 
 
31
  # AdapterWrapper activates its own adapter internally
32
  outputs = self.critic_adapter(
33
  input_ids=input_ids,
 
35
  output_hidden_states=True,
36
  **kwargs,
37
  )
38
+ h_last = outputs.hidden_states[-1] # (B, S, H)
39
  values = self.value_head(h_last).squeeze(-1) # (B, S)
40
  return values
41
 
src_code_for_reproducibility/training/__init__.py CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ File: mllm/training/__init__.py
3
+ Summary: Exposes training submodules through the package namespace.
4
+ """
src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc differ
 
src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc CHANGED
Binary files a/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc differ
 
src_code_for_reproducibility/training/annealing_methods.py CHANGED
@@ -1,6 +1,20 @@
 
 
 
 
 
1
  import numpy as np
2
 
3
 
4
  def sigmoid_annealing(step: int, temperature: float) -> float:
5
- return 2 / (1 + np.exp(-step / temperature)) - 1
 
6
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: mllm/training/annealing_methods.py
3
+ Summary: Implements annealing schedules used across training loops.
4
+ """
5
+
6
  import numpy as np
7
 
8
 
9
  def sigmoid_annealing(step: int, temperature: float) -> float:
10
+ """
11
  + Smoothly ramp a scalar from 0 toward 1 (for non-negative steps) using a temperature-controlled sigmoid.
12
 
13
+ Args:
14
+ step: Current training step or iteration.
15
+ temperature: Controls how sharp the transition is; larger values flatten the curve.
16
+
17
+ Returns:
18
  + Float in [0, 1) for non-negative steps ((-1, 1) over all real steps) that can be rescaled for annealing schedules.
19
+ """
20
+ return 2 / (1 + np.exp(-step / temperature)) - 1