Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- run.log +0 -0
- src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/apply_template.py +12 -1
- src_code_for_reproducibility/chat_utils/chat_turn.py +5 -0
- src_code_for_reproducibility/chat_utils/template_specific.py +27 -0
- src_code_for_reproducibility/markov_games/__init__.py +4 -0
- src_code_for_reproducibility/markov_games/agent.py +18 -22
- src_code_for_reproducibility/markov_games/alternative_actions_runner.py +19 -11
- src_code_for_reproducibility/markov_games/group_timesteps.py +3 -20
- src_code_for_reproducibility/markov_games/linear_runner.py +13 -1
- src_code_for_reproducibility/markov_games/markov_game.py +35 -26
- src_code_for_reproducibility/markov_games/mg_utils.py +16 -8
- src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py +34 -11
- src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py +14 -3
- src_code_for_reproducibility/markov_games/negotiation/tas_agent.py +10 -0
- src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py +10 -0
- src_code_for_reproducibility/markov_games/rollout_tree.py +10 -1
- src_code_for_reproducibility/markov_games/run_markov_games.py +11 -0
- src_code_for_reproducibility/markov_games/simulation.py +25 -18
- src_code_for_reproducibility/markov_games/statistics_runner.py +10 -0
- src_code_for_reproducibility/models/__init__.py +4 -0
- src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/adapter_training_wrapper.py +14 -8
- src_code_for_reproducibility/models/human_policy.py +5 -0
- src_code_for_reproducibility/models/inference_backend.py +5 -0
- src_code_for_reproducibility/models/inference_backend_dummy.py +5 -0
- src_code_for_reproducibility/models/inference_backend_vllm.py +6 -12
- src_code_for_reproducibility/models/large_language_model_api.py +7 -4
- src_code_for_reproducibility/models/large_language_model_local.py +8 -31
- src_code_for_reproducibility/models/scalar_critic.py +14 -9
- src_code_for_reproducibility/training/__init__.py +4 -0
- src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/annealing_methods.py +15 -1
run.log
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc and b/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/chat_utils/apply_template.py
CHANGED
|
@@ -1,10 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
from mllm.chat_utils.chat_turn import ChatTurn
|
| 4 |
from mllm.chat_utils.template_specific import (
|
|
|
|
| 5 |
custom_llama3_template,
|
| 6 |
custom_qwen2_template,
|
| 7 |
custom_qwen3_template,
|
|
|
|
| 8 |
qwen2_assistant_postfix,
|
| 9 |
qwen3_assistant_postfix,
|
| 10 |
)
|
|
@@ -20,6 +27,8 @@ def get_custom_chat_template(tokenizer) -> str:
|
|
| 20 |
return custom_llama3_template
|
| 21 |
elif "qwen3" in tokenizer.name_or_path.lower():
|
| 22 |
return custom_qwen3_template
|
|
|
|
|
|
|
| 23 |
else:
|
| 24 |
raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
|
| 25 |
|
|
@@ -32,13 +41,15 @@ def get_custom_assistant_postfix(tokenizer) -> torch.Tensor:
|
|
| 32 |
return qwen2_assistant_postfix
|
| 33 |
elif "qwen3" in tokenizer.name_or_path.lower():
|
| 34 |
return qwen3_assistant_postfix
|
|
|
|
|
|
|
| 35 |
return torch.tensor([], dtype=torch.long)
|
| 36 |
|
| 37 |
|
| 38 |
def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
|
| 39 |
"""
|
| 40 |
Set the chat_template_token_ids for each chat turn.
|
| 41 |
-
|
| 42 |
"""
|
| 43 |
custom_template = get_custom_chat_template(tokenizer)
|
| 44 |
custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/chat_utils/apply_template.py
|
| 3 |
+
Summary: Applies tokenizer-specific chat templates and stitches chat token IDs.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import torch
|
| 7 |
|
| 8 |
from mllm.chat_utils.chat_turn import ChatTurn
|
| 9 |
from mllm.chat_utils.template_specific import (
|
| 10 |
+
custom_gemma3_template,
|
| 11 |
custom_llama3_template,
|
| 12 |
custom_qwen2_template,
|
| 13 |
custom_qwen3_template,
|
| 14 |
+
gemma3_assistant_postfix,
|
| 15 |
qwen2_assistant_postfix,
|
| 16 |
qwen3_assistant_postfix,
|
| 17 |
)
|
|
|
|
| 27 |
return custom_llama3_template
|
| 28 |
elif "qwen3" in tokenizer.name_or_path.lower():
|
| 29 |
return custom_qwen3_template
|
| 30 |
+
elif "gemma" in tokenizer.name_or_path.lower():
|
| 31 |
+
return custom_gemma3_template
|
| 32 |
else:
|
| 33 |
raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
|
| 34 |
|
|
|
|
| 41 |
return qwen2_assistant_postfix
|
| 42 |
elif "qwen3" in tokenizer.name_or_path.lower():
|
| 43 |
return qwen3_assistant_postfix
|
| 44 |
+
elif "gemma" in tokenizer.name_or_path.lower():
|
| 45 |
+
return gemma3_assistant_postfix
|
| 46 |
return torch.tensor([], dtype=torch.long)
|
| 47 |
|
| 48 |
|
| 49 |
def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
|
| 50 |
"""
|
| 51 |
Set the chat_template_token_ids for each chat turn.
|
| 52 |
+
We rely on tokenizer-side templates because engine-provided cached tokens are not exposed yet.
|
| 53 |
"""
|
| 54 |
custom_template = get_custom_chat_template(tokenizer)
|
| 55 |
custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
|
src_code_for_reproducibility/chat_utils/chat_turn.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/chat_utils/chat_turn.py
|
| 3 |
+
Summary: Defines the ChatTurn schema plus helpers for serialization and validation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import json
|
src_code_for_reproducibility/chat_utils/template_specific.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import huggingface_hub
|
| 2 |
import torch
|
| 3 |
from transformers import AutoTokenizer
|
|
@@ -25,6 +30,11 @@ qwen3_assistant_postfix = (
|
|
| 25 |
.encode("\n", return_tensors="pt")
|
| 26 |
.flatten()
|
| 27 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
custom_qwen2_template = """
|
| 29 |
{%- if add_system_prompt %}
|
| 30 |
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
|
|
@@ -85,3 +95,20 @@ custom_qwen3_template = """
|
|
| 85 |
{%- endif %}
|
| 86 |
{%- endif %}
|
| 87 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/chat_utils/template_specific.py
|
| 3 |
+
Summary: Stores chat template variants and assistant postfix tensors per tokenizer.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import huggingface_hub
|
| 7 |
import torch
|
| 8 |
from transformers import AutoTokenizer
|
|
|
|
| 30 |
.encode("\n", return_tensors="pt")
|
| 31 |
.flatten()
|
| 32 |
)
|
| 33 |
+
gemma3_assistant_postfix = (
|
| 34 |
+
AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
|
| 35 |
+
.encode("\n", return_tensors="pt")
|
| 36 |
+
.flatten()
|
| 37 |
+
)
|
| 38 |
custom_qwen2_template = """
|
| 39 |
{%- if add_system_prompt %}
|
| 40 |
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
|
|
|
|
| 95 |
{%- endif %}
|
| 96 |
{%- endif %}
|
| 97 |
"""
|
| 98 |
+
|
| 99 |
+
custom_gemma3_template = """
|
| 100 |
+
{%- if add_system_prompt %}
|
| 101 |
+
{{- bos_token -}}
|
| 102 |
+
{%- endif %}
|
| 103 |
+
{%- for message in messages -%}
|
| 104 |
+
{%- if message['role'] == 'assistant' -%}
|
| 105 |
+
{%- set role = 'model' -%}
|
| 106 |
+
{%- else -%}
|
| 107 |
+
{%- set role = message['role'] -%}
|
| 108 |
+
{%- endif -%}
|
| 109 |
+
{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
|
| 110 |
+
{%- endfor -%}
|
| 111 |
+
{%- if add_generation_prompt -%}
|
| 112 |
+
{{ '<start_of_turn>model\n' }}
|
| 113 |
+
{%- endif -%}
|
| 114 |
+
"""
|
src_code_for_reproducibility/markov_games/__init__.py
CHANGED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/__init__.py
|
| 3 |
+
Summary: Makes Markov-game subpackages importable from the top-level namespace.
|
| 4 |
+
"""
|
src_code_for_reproducibility/markov_games/agent.py
CHANGED
|
@@ -1,11 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
a direct path from policy to action. For instance, from the observation of the environment,
|
| 5 |
-
a prompt must be created. Then, the outputs of the policy might be incorrect, so a second
|
| 6 |
-
request to the LLM must be sent before the action is well defined. This is why this Agent class exists.
|
| 7 |
-
It acts as a mini environment, bridging the gap between the core simulation and
|
| 8 |
-
the LLM policies.
|
| 9 |
"""
|
| 10 |
|
| 11 |
from abc import ABC, abstractmethod
|
|
@@ -18,6 +13,8 @@ from mllm.markov_games.rollout_tree import AgentActLog
|
|
| 18 |
|
| 19 |
|
| 20 |
class Agent(ABC):
|
|
|
|
|
|
|
| 21 |
@abstractmethod
|
| 22 |
def __init__(
|
| 23 |
self,
|
|
@@ -29,7 +26,10 @@ class Agent(ABC):
|
|
| 29 |
**kwargs,
|
| 30 |
):
|
| 31 |
"""
|
| 32 |
-
Initialize the agent state.
|
|
|
|
|
|
|
|
|
|
| 33 |
"""
|
| 34 |
self.seed = seed
|
| 35 |
self.agent_id = agent_id
|
|
@@ -40,37 +40,33 @@ class Agent(ABC):
|
|
| 40 |
|
| 41 |
async def act(self, observation) -> Tuple[Any, AgentActLog]:
|
| 42 |
"""
|
| 43 |
-
|
| 44 |
-
obtain the action of the agent.
|
| 45 |
|
| 46 |
-
|
| 47 |
-
action
|
| 48 |
-
prompt = self.observation_to_prompt(observation)
|
| 49 |
-
while not self.valid(action):
|
| 50 |
-
output = await self.policy.generate(prompt)
|
| 51 |
-
action = self.policy_output_to_action(output)
|
| 52 |
-
return action
|
| 53 |
-
|
| 54 |
-
Returns:
|
| 55 |
-
action
|
| 56 |
-
step_info
|
| 57 |
"""
|
| 58 |
raise NotImplementedError
|
| 59 |
|
| 60 |
def get_safe_copy(self):
|
| 61 |
"""
|
| 62 |
-
Return
|
|
|
|
|
|
|
| 63 |
"""
|
| 64 |
raise NotImplementedError
|
| 65 |
|
| 66 |
def reset(self):
|
|
|
|
| 67 |
raise NotImplementedError
|
| 68 |
|
| 69 |
def render(self):
|
|
|
|
| 70 |
raise NotImplementedError
|
| 71 |
|
| 72 |
def close(self):
|
|
|
|
| 73 |
raise NotImplementedError
|
| 74 |
|
| 75 |
def get_agent_info(self):
|
|
|
|
| 76 |
raise NotImplementedError
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/markov_games/agent.py
|
| 3 |
+
Summary: Declares the base Agent interface connecting simulations to policy calls.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from abc import ABC, abstractmethod
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class Agent(ABC):
|
| 16 |
+
"""Abstract policy wrapper that bridges simulations with arbitrary backends."""
|
| 17 |
+
|
| 18 |
@abstractmethod
|
| 19 |
def __init__(
|
| 20 |
self,
|
|
|
|
| 26 |
**kwargs,
|
| 27 |
):
|
| 28 |
"""
|
| 29 |
+
Initialize the agent state and seed its RNG.
|
| 30 |
+
|
| 31 |
+
Subclasses typically store extra handles (tokenizers, inference clients, etc.)
|
| 32 |
+
but they should always call ``super().__init__`` so sampling remains reproducible.
|
| 33 |
"""
|
| 34 |
self.seed = seed
|
| 35 |
self.agent_id = agent_id
|
|
|
|
| 40 |
|
| 41 |
async def act(self, observation) -> Tuple[Any, AgentActLog]:
|
| 42 |
"""
|
| 43 |
+
Produce the next action (and associated chat log) given an environment observation.
|
|
|
|
| 44 |
|
| 45 |
+
Implementations can iterate with rejection sampling, multi-call deliberation, etc.
|
| 46 |
+
Returns both the chosen action and an `AgentActLog` describing how it was produced.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
| 48 |
raise NotImplementedError
|
| 49 |
|
| 50 |
def get_safe_copy(self):
|
| 51 |
"""
|
| 52 |
+
Return a deep copy whose future calls do not mutate the original agent.
|
| 53 |
+
|
| 54 |
+
Needed for branch exploration/reruns with alternative actions.
|
| 55 |
"""
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
def reset(self):
|
| 59 |
+
"""Reset any internal state between rollouts."""
|
| 60 |
raise NotImplementedError
|
| 61 |
|
| 62 |
def render(self):
|
| 63 |
+
"""Optional human-readable visualization of the agent (CLI/UI)."""
|
| 64 |
raise NotImplementedError
|
| 65 |
|
| 66 |
def close(self):
|
| 67 |
+
"""Release any external resources (network sockets, subprocesses, etc.)."""
|
| 68 |
raise NotImplementedError
|
| 69 |
|
| 70 |
def get_agent_info(self):
|
| 71 |
+
"""Return diagnostic metadata to embed inside rollout logs."""
|
| 72 |
raise NotImplementedError
|
src_code_for_reproducibility/markov_games/alternative_actions_runner.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import copy
|
| 3 |
import json
|
|
@@ -16,7 +21,6 @@ from mllm.markov_games.rollout_tree import (
|
|
| 16 |
AgentId = str
|
| 17 |
|
| 18 |
|
| 19 |
-
|
| 20 |
async def run_with_unilateral_alt_action(
|
| 21 |
markov_game: MarkovGame,
|
| 22 |
agent_id: AgentId,
|
|
@@ -25,7 +29,11 @@ async def run_with_unilateral_alt_action(
|
|
| 25 |
max_depth: int,
|
| 26 |
):
|
| 27 |
"""
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
# Generate alternative action and take a step
|
|
@@ -65,20 +73,20 @@ async def AlternativeActionsRunner(
|
|
| 65 |
branch_only_on_new_round: bool = False,
|
| 66 |
):
|
| 67 |
"""
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
| 73 |
"""
|
| 74 |
|
| 75 |
tasks = []
|
| 76 |
time_step = 0
|
| 77 |
terminated = False
|
| 78 |
-
root = RolloutTreeRootNode(
|
| 79 |
-
id=markov_game.get_id(),
|
| 80 |
-
crn_id=markov_game.get_crn_id()
|
| 81 |
-
)
|
| 82 |
previous_node = root
|
| 83 |
|
| 84 |
while not terminated:
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/alternative_actions_runner.py
|
| 3 |
+
Summary: Generates rollout branches by replaying trajectories with unilateral action changes.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
import copy
|
| 8 |
import json
|
|
|
|
| 21 |
AgentId = str
|
| 22 |
|
| 23 |
|
|
|
|
| 24 |
async def run_with_unilateral_alt_action(
|
| 25 |
markov_game: MarkovGame,
|
| 26 |
agent_id: AgentId,
|
|
|
|
| 29 |
max_depth: int,
|
| 30 |
):
|
| 31 |
"""
|
| 32 |
+
Roll out a counterfactual branch where ``agent_id`` deviates unilaterally.
|
| 33 |
+
|
| 34 |
+
Starting from ``branch_node`` (which already contains the main trajectory),
|
| 35 |
+
we replay the simulation with the deviating agent's action while freezing
|
| 36 |
+
all other agents/actions, then continue for ``max_depth`` steps.
|
| 37 |
"""
|
| 38 |
|
| 39 |
# Generate alternative action and take a step
|
|
|
|
| 73 |
branch_only_on_new_round: bool = False,
|
| 74 |
):
|
| 75 |
"""
|
| 76 |
+
Generate a rollout tree containing the main path plus unilateral deviation branches.
|
| 77 |
+
|
| 78 |
+
For each timestep we:
|
| 79 |
+
1. Cache agent actions without side effects.
|
| 80 |
+
2. Advance the main trajectory.
|
| 81 |
+
3. Spawn ``nb_alternative_actions`` asynchronous deviations per agent,
|
| 82 |
+
each replaying up to ``max_depth`` steps from the cached pre-action state.
|
| 83 |
+
The resulting branches feed advantage-alignment estimators.
|
| 84 |
"""
|
| 85 |
|
| 86 |
tasks = []
|
| 87 |
time_step = 0
|
| 88 |
terminated = False
|
| 89 |
+
root = RolloutTreeRootNode(id=markov_game.get_id(), crn_id=markov_game.get_crn_id())
|
|
|
|
|
|
|
|
|
|
| 90 |
previous_node = root
|
| 91 |
|
| 92 |
while not terminated:
|
src_code_for_reproducibility/markov_games/group_timesteps.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
-
|
|
|
|
| 3 |
"""
|
|
|
|
| 4 |
import copy
|
| 5 |
from typing import Callable
|
| 6 |
|
|
@@ -84,25 +86,6 @@ def group_time_steps(
|
|
| 84 |
raise Exception(
|
| 85 |
"Grouping timesteps by round is not supported for branching trajectories yet."
|
| 86 |
)
|
| 87 |
-
# Special recursive case for branches
|
| 88 |
-
# if isinstance(current_node, RolloutTreeBranchNode):
|
| 89 |
-
# branches = {}
|
| 90 |
-
# for agent_id, branch_nodes in current_node.branches.items():
|
| 91 |
-
# branch_group_nodes = []
|
| 92 |
-
# for branch_node in branch_nodes:
|
| 93 |
-
# branch_group_node = group_time_steps_rec(
|
| 94 |
-
# current_node=branch_node,
|
| 95 |
-
# group_time_step=group_time_step,
|
| 96 |
-
# accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
|
| 97 |
-
# branch_group_nodes.append(branch_group_node)
|
| 98 |
-
# branches[agent_id] = branch_group_nodes
|
| 99 |
-
|
| 100 |
-
# main_child_group_node = group_time_steps_rec(
|
| 101 |
-
# current_node=current_node.main_child,
|
| 102 |
-
# group_time_step=group_time_step,
|
| 103 |
-
# accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
|
| 104 |
-
|
| 105 |
-
# return RolloutTreeBranchNode(main_child=main_child_group_node, branches=branches)
|
| 106 |
|
| 107 |
# Accumulate
|
| 108 |
accumulation_step_logs.append(current_node.step_log)
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/markov_games/group_timesteps.py
|
| 3 |
+
Summary: Provides timestep-grouping utilities for rollout trees and training.
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import copy
|
| 7 |
from typing import Callable
|
| 8 |
|
|
|
|
| 86 |
raise Exception(
|
| 87 |
"Grouping timesteps by round is not supported for branching trajectories yet."
|
| 88 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# Accumulate
|
| 91 |
accumulation_step_logs.append(current_node.step_log)
|
src_code_for_reproducibility/markov_games/linear_runner.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
import os.path
|
|
@@ -10,7 +15,14 @@ async def LinearRunner(
|
|
| 10 |
markov_game: MarkovGame, output_folder: str
|
| 11 |
) -> RolloutTreeRootNode:
|
| 12 |
"""
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
time_step = 0
|
| 16 |
terminated = False
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/linear_runner.py
|
| 3 |
+
Summary: Simulates a single unbranched Markov-game rollout and records it.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
import json
|
| 8 |
import os.path
|
|
|
|
| 15 |
markov_game: MarkovGame, output_folder: str
|
| 16 |
) -> RolloutTreeRootNode:
|
| 17 |
"""
|
| 18 |
+
Generate a single main-path rollout (no branching) for the provided Markov game.
|
| 19 |
+
|
| 20 |
+
Parameters
|
| 21 |
+
----------
|
| 22 |
+
markov_game:
|
| 23 |
+
Initialized ``MarkovGame`` with agents + simulation ready to step.
|
| 24 |
+
output_folder:
|
| 25 |
+
Unused placeholder in the legacy API (kept for compatibility).
|
| 26 |
"""
|
| 27 |
time_step = 0
|
| 28 |
terminated = False
|
src_code_for_reproducibility/markov_games/markov_game.py
CHANGED
|
@@ -1,18 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
1) each agent takes an action,
|
| 5 |
-
2) the state transitions with respect to these actions,
|
| 6 |
-
3) all relevant data of the step is appended to the historical data list
|
| 7 |
-
|
| 8 |
-
In order to perform 3), the agents and the simulation are expected, at each time step,
|
| 9 |
-
to return a log of the state transition (from their perspective).
|
| 10 |
-
For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
|
| 11 |
-
A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
|
| 12 |
-
The approach we use here centralizes the data gathering aspect,
|
| 13 |
-
making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
|
| 14 |
-
only log information for step transitions occuring after the branching out.
|
| 15 |
"""
|
|
|
|
| 16 |
import asyncio
|
| 17 |
import copy
|
| 18 |
import json
|
|
@@ -31,6 +21,8 @@ AgentId = str
|
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
class AgentAndActionSafeCopy:
|
|
|
|
|
|
|
| 34 |
action: Any
|
| 35 |
action_info: AgentActLog
|
| 36 |
agent_after_action: type[Agent]
|
|
@@ -45,12 +37,18 @@ class MarkovGame(object):
|
|
| 45 |
crn_id: int,
|
| 46 |
):
|
| 47 |
"""
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
"""
|
| 55 |
self.agents = agents
|
| 56 |
self.agent_ids = self.agents.keys()
|
|
@@ -131,7 +129,7 @@ class MarkovGame(object):
|
|
| 131 |
|
| 132 |
async def set_action_of_agent(self, agent_id: AgentId):
|
| 133 |
"""
|
| 134 |
-
|
| 135 |
"""
|
| 136 |
agent = self.agents[agent_id]
|
| 137 |
obs = self.simulation.get_obs_agent(agent_id)
|
|
@@ -141,7 +139,7 @@ class MarkovGame(object):
|
|
| 141 |
|
| 142 |
async def set_actions(self):
|
| 143 |
"""
|
| 144 |
-
|
| 145 |
"""
|
| 146 |
# background_tasks = set()
|
| 147 |
tasks = []
|
|
@@ -152,16 +150,27 @@ class MarkovGame(object):
|
|
| 152 |
|
| 153 |
def take_simulation_step(self):
|
| 154 |
"""
|
| 155 |
-
|
| 156 |
"""
|
| 157 |
terminated, self.simulation_step_log = self.simulation.step(self.actions)
|
| 158 |
return terminated
|
| 159 |
|
| 160 |
def get_step_log(self) -> StepLog:
|
| 161 |
"""
|
| 162 |
-
|
| 163 |
-
TODO: assert actions and simulation have taken step
|
| 164 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
step_log = StepLog(
|
| 166 |
simulation_step_log=self.simulation_step_log,
|
| 167 |
action_logs=self.agent_step_logs,
|
|
@@ -170,7 +179,7 @@ class MarkovGame(object):
|
|
| 170 |
|
| 171 |
async def step(self) -> Tuple[bool, StepLog]:
|
| 172 |
"""
|
| 173 |
-
|
| 174 |
"""
|
| 175 |
await self.set_actions()
|
| 176 |
terminated = self.take_simulation_step()
|
|
@@ -179,7 +188,7 @@ class MarkovGame(object):
|
|
| 179 |
|
| 180 |
def get_safe_copy(self):
|
| 181 |
"""
|
| 182 |
-
|
| 183 |
"""
|
| 184 |
|
| 185 |
new_markov_game = copy.copy(self)
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/markov_games/markov_game.py
|
| 3 |
+
Summary: Defines the MarkovGame base class plus shared simulation interfaces.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
import copy
|
| 8 |
import json
|
|
|
|
| 21 |
|
| 22 |
@dataclass
|
| 23 |
class AgentAndActionSafeCopy:
|
| 24 |
+
"""Snapshot of an agent, its action, and metadata used for branch replay."""
|
| 25 |
+
|
| 26 |
action: Any
|
| 27 |
action_info: AgentActLog
|
| 28 |
agent_after_action: type[Agent]
|
|
|
|
| 37 |
crn_id: int,
|
| 38 |
):
|
| 39 |
"""
|
| 40 |
+
Initialize the Markov game wrapper.
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
----------
|
| 44 |
+
id:
|
| 45 |
+
Unique rollout identifier (logged into rollout trees).
|
| 46 |
+
agents:
|
| 47 |
+
Mapping of agent_id -> Agent instance.
|
| 48 |
+
simulation:
|
| 49 |
+
Environment implementing the ``Simulation`` interface (IPD, TAS, etc.).
|
| 50 |
+
crn_id:
|
| 51 |
+
Identifier for the common random number stream used by this rollout.
|
| 52 |
"""
|
| 53 |
self.agents = agents
|
| 54 |
self.agent_ids = self.agents.keys()
|
|
|
|
| 129 |
|
| 130 |
async def set_action_of_agent(self, agent_id: AgentId):
|
| 131 |
"""
|
| 132 |
+
Query a single agent for its next action and store the result locally.
|
| 133 |
"""
|
| 134 |
agent = self.agents[agent_id]
|
| 135 |
obs = self.simulation.get_obs_agent(agent_id)
|
|
|
|
| 139 |
|
| 140 |
async def set_actions(self):
|
| 141 |
"""
|
| 142 |
+
Query every agent concurrently and populate the cached actions/logs.
|
| 143 |
"""
|
| 144 |
# background_tasks = set()
|
| 145 |
tasks = []
|
|
|
|
| 150 |
|
| 151 |
def take_simulation_step(self):
|
| 152 |
"""
|
| 153 |
+
Advance the simulation by one step using the cached actions.
|
| 154 |
"""
|
| 155 |
terminated, self.simulation_step_log = self.simulation.step(self.actions)
|
| 156 |
return terminated
|
| 157 |
|
| 158 |
def get_step_log(self) -> StepLog:
|
| 159 |
"""
|
| 160 |
+
Package the most recent simulation step and agent logs into a StepLog.
|
|
|
|
| 161 |
"""
|
| 162 |
+
if self.simulation_step_log is None:
|
| 163 |
+
raise RuntimeError(
|
| 164 |
+
"Simulation step log is empty; call take_simulation_step() first."
|
| 165 |
+
)
|
| 166 |
+
missing_logs = [
|
| 167 |
+
agent_id for agent_id, log in self.agent_step_logs.items() if log is None
|
| 168 |
+
]
|
| 169 |
+
if missing_logs:
|
| 170 |
+
raise RuntimeError(
|
| 171 |
+
f"Agent action logs missing for: {', '.join(missing_logs)}. "
|
| 172 |
+
"Ensure set_actions() ran before requesting the step log."
|
| 173 |
+
)
|
| 174 |
step_log = StepLog(
|
| 175 |
simulation_step_log=self.simulation_step_log,
|
| 176 |
action_logs=self.agent_step_logs,
|
|
|
|
| 179 |
|
| 180 |
async def step(self) -> Tuple[bool, StepLog]:
|
| 181 |
"""
|
| 182 |
+
Convenience step that collects actions, advances the simulation, and returns the log.
|
| 183 |
"""
|
| 184 |
await self.set_actions()
|
| 185 |
terminated = self.take_simulation_step()
|
|
|
|
| 188 |
|
| 189 |
def get_safe_copy(self):
|
| 190 |
"""
|
| 191 |
+
Create a shallow copy of the game with deep-copied agents/simulation for branching.
|
| 192 |
"""
|
| 193 |
|
| 194 |
new_markov_game = copy.copy(self)
|
src_code_for_reproducibility/markov_games/mg_utils.py
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import copy
|
| 3 |
from collections.abc import Callable
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
from mllm.markov_games.ipd.ipd_agent import IPDAgent
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from mllm.markov_games.ipd.ipd_simulation import IPD
|
| 8 |
from mllm.markov_games.markov_game import MarkovGame
|
| 9 |
from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
|
|
@@ -12,17 +21,10 @@ from mllm.markov_games.negotiation.nego_hard_coded_policies import (
|
|
| 12 |
HardCodedNegoGreedyPolicy,
|
| 13 |
HardCodedNegoWelfareMaximizingPolicy,
|
| 14 |
)
|
| 15 |
-
from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
|
| 16 |
from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
|
| 17 |
from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
|
| 18 |
-
from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
|
| 19 |
from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
|
| 20 |
from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
|
| 21 |
-
from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent
|
| 22 |
-
from mllm.markov_games.negotiation.tas_simple_simulation import (
|
| 23 |
-
TrustAndSplitSimpleSimulation,
|
| 24 |
-
)
|
| 25 |
-
from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
|
| 26 |
from mllm.markov_games.rollout_tree import (
|
| 27 |
AgentActLog,
|
| 28 |
RolloutTreeBranchNode,
|
|
@@ -37,6 +39,8 @@ AgentId = str
|
|
| 37 |
|
| 38 |
@dataclass
|
| 39 |
class AgentConfig:
|
|
|
|
|
|
|
| 40 |
agent_id: str
|
| 41 |
agent_name: str
|
| 42 |
agent_class_name: str
|
|
@@ -46,6 +50,8 @@ class AgentConfig:
|
|
| 46 |
|
| 47 |
@dataclass
|
| 48 |
class MarkovGameConfig:
|
|
|
|
|
|
|
| 49 |
id: int
|
| 50 |
seed: int
|
| 51 |
simulation_class_name: str
|
|
@@ -57,7 +63,9 @@ def init_markov_game_components(
|
|
| 57 |
config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
|
| 58 |
):
|
| 59 |
"""
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
"""
|
| 62 |
agents = {}
|
| 63 |
agent_names = []
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/mg_utils.py
|
| 3 |
+
Summary: Holds miscellaneous helpers shared across Markov-game modules.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
import copy
|
| 8 |
from collections.abc import Callable
|
| 9 |
from dataclasses import dataclass
|
| 10 |
|
| 11 |
from mllm.markov_games.ipd.ipd_agent import IPDAgent
|
| 12 |
+
from mllm.markov_games.ipd.Ipd_hard_coded_agents import (
|
| 13 |
+
AlwaysCooperateIPDAgent,
|
| 14 |
+
AlwaysDefectIPDAgent,
|
| 15 |
+
)
|
| 16 |
from mllm.markov_games.ipd.ipd_simulation import IPD
|
| 17 |
from mllm.markov_games.markov_game import MarkovGame
|
| 18 |
from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
|
|
|
|
| 21 |
HardCodedNegoGreedyPolicy,
|
| 22 |
HardCodedNegoWelfareMaximizingPolicy,
|
| 23 |
)
|
|
|
|
| 24 |
from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
|
| 25 |
from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
|
|
|
|
| 26 |
from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
|
| 27 |
from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from mllm.markov_games.rollout_tree import (
|
| 29 |
AgentActLog,
|
| 30 |
RolloutTreeBranchNode,
|
|
|
|
| 39 |
|
| 40 |
@dataclass
|
| 41 |
class AgentConfig:
|
| 42 |
+
"""Configuration blob describing one agent in a Markov game spec."""
|
| 43 |
+
|
| 44 |
agent_id: str
|
| 45 |
agent_name: str
|
| 46 |
agent_class_name: str
|
|
|
|
| 50 |
|
| 51 |
@dataclass
|
| 52 |
class MarkovGameConfig:
|
| 53 |
+
"""Top-level config that ties together simulation settings and agent configs."""
|
| 54 |
+
|
| 55 |
id: int
|
| 56 |
seed: int
|
| 57 |
simulation_class_name: str
|
|
|
|
| 63 |
config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
|
| 64 |
):
|
| 65 |
"""
|
| 66 |
+
Materialize Agents and the Simulation described by ``config`` and return a MarkovGame.
|
| 67 |
+
|
| 68 |
+
`policies` is a mapping of policy_id -> callable retrieved from the hosting trainer.
|
| 69 |
"""
|
| 70 |
agents = {}
|
| 71 |
agent_names = []
|
src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py
CHANGED
|
@@ -1,30 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import copy
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from typing import Any, Dict, List, Tuple
|
| 4 |
|
| 5 |
from numpy.random import default_rng
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from mllm.markov_games.rollout_tree import SimulationStepLog
|
| 8 |
-
from mllm.markov_games.negotiation.nego_simulation import Split, NegotiationState, NegotiationObs, NegotiationSimulation
|
| 9 |
from mllm.utils.get_coagent_id import get_coagent_id
|
| 10 |
|
| 11 |
-
|
| 12 |
AgentId = str
|
| 13 |
|
| 14 |
|
| 15 |
@dataclass
|
| 16 |
class DealNoDealState(NegotiationState):
|
|
|
|
|
|
|
| 17 |
item_types: List[str]
|
| 18 |
values: Dict[AgentId, Dict[str, int]]
|
| 19 |
|
|
|
|
| 20 |
@dataclass
|
| 21 |
class DealNoDealObs(NegotiationObs):
|
|
|
|
|
|
|
| 22 |
my_values: Dict[str, int]
|
| 23 |
item_types: List[str]
|
| 24 |
previous_values_coagent: Dict[str, int] | None
|
| 25 |
|
| 26 |
|
| 27 |
def random_partition_integer(rng, total: int, parts: int) -> List[int]:
|
|
|
|
| 28 |
if parts <= 0:
|
| 29 |
return []
|
| 30 |
if total <= 0:
|
|
@@ -37,7 +52,9 @@ def random_partition_integer(rng, total: int, parts: int) -> List[int]:
|
|
| 37 |
prev = c
|
| 38 |
return vals
|
| 39 |
|
|
|
|
| 40 |
class DealNoDealSimulation(NegotiationSimulation):
|
|
|
|
| 41 |
|
| 42 |
def __init__(
|
| 43 |
self,
|
|
@@ -75,7 +92,9 @@ class DealNoDealSimulation(NegotiationSimulation):
|
|
| 75 |
if ok1 and ok2:
|
| 76 |
return {self.agent_ids[0]: a, self.agent_ids[1]: b}
|
| 77 |
|
| 78 |
-
def _is_valid_allocation(
|
|
|
|
|
|
|
| 79 |
for t in self.item_types:
|
| 80 |
v = allocation.get(t)
|
| 81 |
if v is None:
|
|
@@ -85,16 +104,18 @@ class DealNoDealSimulation(NegotiationSimulation):
|
|
| 85 |
if v < 0 or v > int(stock.get(t, 0)):
|
| 86 |
return False
|
| 87 |
return True
|
| 88 |
-
|
| 89 |
def set_new_round_of_variant(self):
|
| 90 |
# Keep same values, resample stock
|
| 91 |
self.state.quantities = self._sample_stock()
|
| 92 |
|
| 93 |
-
def get_info_of_variant(
|
|
|
|
|
|
|
| 94 |
return {
|
| 95 |
"quantities": copy.deepcopy(state.quantities),
|
| 96 |
"values": copy.deepcopy(state.values),
|
| 97 |
-
|
| 98 |
}
|
| 99 |
|
| 100 |
def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
|
|
@@ -105,11 +126,15 @@ class DealNoDealSimulation(NegotiationSimulation):
|
|
| 105 |
split_b = splits[self.agent_ids[1]].items_given_to_self
|
| 106 |
rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
|
| 107 |
for t in self.item_types:
|
| 108 |
-
# If not complementary, return 0!
|
| 109 |
if not split_a[t] + split_b[t] == self.state.quantities[t]:
|
| 110 |
return {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
|
| 111 |
-
rewards[self.agent_ids[0]] +=
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
return rewards
|
| 114 |
|
| 115 |
def get_obs(self):
|
|
@@ -149,5 +174,3 @@ class DealNoDealSimulation(NegotiationSimulation):
|
|
| 149 |
item_types=list(self.item_types),
|
| 150 |
)
|
| 151 |
return self.get_obs()
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/negotiation/dond_simulation.py
|
| 3 |
+
Summary: Simulates Deal-or-No-Deal negotiation games and logs rollouts.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import copy
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import Any, Dict, List, Tuple
|
| 9 |
|
| 10 |
from numpy.random import default_rng
|
| 11 |
|
| 12 |
+
from mllm.markov_games.negotiation.nego_simulation import (
|
| 13 |
+
NegotiationObs,
|
| 14 |
+
NegotiationSimulation,
|
| 15 |
+
NegotiationState,
|
| 16 |
+
Split,
|
| 17 |
+
)
|
| 18 |
from mllm.markov_games.rollout_tree import SimulationStepLog
|
|
|
|
| 19 |
from mllm.utils.get_coagent_id import get_coagent_id
|
| 20 |
|
|
|
|
| 21 |
AgentId = str
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
class DealNoDealState(NegotiationState):
|
| 26 |
+
"""NegotiationState with per-agent value tables and item taxonomy."""
|
| 27 |
+
|
| 28 |
item_types: List[str]
|
| 29 |
values: Dict[AgentId, Dict[str, int]]
|
| 30 |
|
| 31 |
+
|
| 32 |
@dataclass
|
| 33 |
class DealNoDealObs(NegotiationObs):
|
| 34 |
+
"""Observation that reveals own values and (lagged) opponent values."""
|
| 35 |
+
|
| 36 |
my_values: Dict[str, int]
|
| 37 |
item_types: List[str]
|
| 38 |
previous_values_coagent: Dict[str, int] | None
|
| 39 |
|
| 40 |
|
| 41 |
def random_partition_integer(rng, total: int, parts: int) -> List[int]:
|
| 42 |
+
"""Sample non-negative integers summing to ``total`` across ``parts`` buckets."""
|
| 43 |
if parts <= 0:
|
| 44 |
return []
|
| 45 |
if total <= 0:
|
|
|
|
| 52 |
prev = c
|
| 53 |
return vals
|
| 54 |
|
| 55 |
+
|
| 56 |
class DealNoDealSimulation(NegotiationSimulation):
|
| 57 |
+
"""NegotiationSimulation variant implementing the Rubinstein-style Deal-or-No-Deal."""
|
| 58 |
|
| 59 |
def __init__(
|
| 60 |
self,
|
|
|
|
| 92 |
if ok1 and ok2:
|
| 93 |
return {self.agent_ids[0]: a, self.agent_ids[1]: b}
|
| 94 |
|
| 95 |
+
def _is_valid_allocation(
|
| 96 |
+
self, allocation: Dict[str, int], stock: Dict[str, int]
|
| 97 |
+
) -> bool:
|
| 98 |
for t in self.item_types:
|
| 99 |
v = allocation.get(t)
|
| 100 |
if v is None:
|
|
|
|
| 104 |
if v < 0 or v > int(stock.get(t, 0)):
|
| 105 |
return False
|
| 106 |
return True
|
| 107 |
+
|
| 108 |
def set_new_round_of_variant(self):
|
| 109 |
# Keep same values, resample stock
|
| 110 |
self.state.quantities = self._sample_stock()
|
| 111 |
|
| 112 |
+
def get_info_of_variant(
|
| 113 |
+
self, state: NegotiationState, actions: Dict[AgentId, Any]
|
| 114 |
+
) -> Dict[str, Any]:
|
| 115 |
return {
|
| 116 |
"quantities": copy.deepcopy(state.quantities),
|
| 117 |
"values": copy.deepcopy(state.values),
|
| 118 |
+
"splits": copy.deepcopy(state.splits),
|
| 119 |
}
|
| 120 |
|
| 121 |
def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
|
|
|
|
| 126 |
split_b = splits[self.agent_ids[1]].items_given_to_self
|
| 127 |
rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
|
| 128 |
for t in self.item_types:
|
| 129 |
+
# If not complementary, return 0!
|
| 130 |
if not split_a[t] + split_b[t] == self.state.quantities[t]:
|
| 131 |
return {self.agent_ids[0]: 0, self.agent_ids[1]: 0}
|
| 132 |
+
rewards[self.agent_ids[0]] += (
|
| 133 |
+
split_a[t] * self.state.values[self.agent_ids[0]][t]
|
| 134 |
+
)
|
| 135 |
+
rewards[self.agent_ids[1]] += (
|
| 136 |
+
split_b[t] * self.state.values[self.agent_ids[1]][t]
|
| 137 |
+
)
|
| 138 |
return rewards
|
| 139 |
|
| 140 |
def get_obs(self):
|
|
|
|
| 174 |
item_types=list(self.item_types),
|
| 175 |
)
|
| 176 |
return self.get_obs()
|
|
|
|
|
|
src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
"""
|
|
|
|
| 5 |
import copy
|
| 6 |
from abc import abstractmethod
|
| 7 |
from dataclasses import dataclass
|
|
@@ -18,16 +19,22 @@ AgentId = str
|
|
| 18 |
|
| 19 |
@dataclass
|
| 20 |
class Split:
|
|
|
|
|
|
|
| 21 |
items_given_to_self: Dict[str, int]
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
class Message:
|
|
|
|
|
|
|
| 26 |
message: str
|
| 27 |
|
| 28 |
|
| 29 |
@dataclass # gets extended by variants
|
| 30 |
class NegotiationState:
|
|
|
|
|
|
|
| 31 |
round_nb: int
|
| 32 |
last_message: str
|
| 33 |
current_agent: AgentId
|
|
@@ -44,6 +51,8 @@ class NegotiationState:
|
|
| 44 |
|
| 45 |
@dataclass # gets extended by variants
|
| 46 |
class NegotiationObs:
|
|
|
|
|
|
|
| 47 |
round_nb: int
|
| 48 |
last_message: str
|
| 49 |
quota_messages_per_agent_per_round: int
|
|
@@ -134,12 +143,14 @@ class NegotiationSimulation(Simulation):
|
|
| 134 |
|
| 135 |
@abstractmethod
|
| 136 |
def set_new_round_of_variant(self):
|
|
|
|
| 137 |
pass
|
| 138 |
|
| 139 |
@abstractmethod
|
| 140 |
def get_info_of_variant(
|
| 141 |
self, state: NegotiationState, actions: Dict[AgentId, Any]
|
| 142 |
) -> Dict[str, Any]:
|
|
|
|
| 143 |
pass
|
| 144 |
|
| 145 |
def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
|
|
@@ -190,7 +201,7 @@ class NegotiationSimulation(Simulation):
|
|
| 190 |
is_last_timestep_in_round = True
|
| 191 |
done = self.state.round_nb >= self.nb_of_rounds
|
| 192 |
|
| 193 |
-
# Message phase
|
| 194 |
elif isinstance(action, Message):
|
| 195 |
self.state.last_message = action.message
|
| 196 |
self.state.nb_messages_sent[current_agent] += 1
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/markov_games/negotiation/nego_simulation.py
|
| 3 |
+
Summary: Simulation harness for general negotiation environments.
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import copy
|
| 7 |
from abc import abstractmethod
|
| 8 |
from dataclasses import dataclass
|
|
|
|
| 19 |
|
| 20 |
@dataclass
|
| 21 |
class Split:
|
| 22 |
+
"""Structured proposal describing how many units of each item an agent keeps."""
|
| 23 |
+
|
| 24 |
items_given_to_self: Dict[str, int]
|
| 25 |
|
| 26 |
|
| 27 |
@dataclass
|
| 28 |
class Message:
|
| 29 |
+
"""Single chat utterance exchanged during the negotiation phase."""
|
| 30 |
+
|
| 31 |
message: str
|
| 32 |
|
| 33 |
|
| 34 |
@dataclass # gets extended by variants
|
| 35 |
class NegotiationState:
|
| 36 |
+
"""Full simulator state snapshot shared by all negotiation variants."""
|
| 37 |
+
|
| 38 |
round_nb: int
|
| 39 |
last_message: str
|
| 40 |
current_agent: AgentId
|
|
|
|
| 51 |
|
| 52 |
@dataclass # gets extended by variants
|
| 53 |
class NegotiationObs:
|
| 54 |
+
"""Observation presented to agents each turn (base fields; variants extend)."""
|
| 55 |
+
|
| 56 |
round_nb: int
|
| 57 |
last_message: str
|
| 58 |
quota_messages_per_agent_per_round: int
|
|
|
|
| 143 |
|
| 144 |
@abstractmethod
|
| 145 |
def set_new_round_of_variant(self):
|
| 146 |
+
"""Variant hook: sample new private values / stock before each round."""
|
| 147 |
pass
|
| 148 |
|
| 149 |
@abstractmethod
|
| 150 |
def get_info_of_variant(
|
| 151 |
self, state: NegotiationState, actions: Dict[AgentId, Any]
|
| 152 |
) -> Dict[str, Any]:
|
| 153 |
+
"""Variant hook: populate SimulationStepLog.info with custom diagnostics."""
|
| 154 |
pass
|
| 155 |
|
| 156 |
def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
|
|
|
|
| 201 |
is_last_timestep_in_round = True
|
| 202 |
done = self.state.round_nb >= self.nb_of_rounds
|
| 203 |
|
| 204 |
+
# Message phase: roll the conversation forward a single turn.
|
| 205 |
elif isinstance(action, Message):
|
| 206 |
self.state.last_message = action.message
|
| 207 |
self.state.nb_messages_sent[current_agent] += 1
|
src_code_for_reproducibility/markov_games/negotiation/tas_agent.py
CHANGED
|
@@ -1,9 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
|
| 2 |
from mllm.markov_games.negotiation.nego_simulation import Split
|
| 3 |
from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
|
| 4 |
|
| 5 |
|
| 6 |
class TrustAndSplitAgent(NegotiationAgent):
|
|
|
|
|
|
|
| 7 |
def __init__(self, num_message_chars, *args, **kwargs):
|
| 8 |
self.num_message_chars = num_message_chars
|
| 9 |
super().__init__(*args, **kwargs)
|
|
@@ -58,12 +65,14 @@ class TrustAndSplitAgent(NegotiationAgent):
|
|
| 58 |
self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
|
| 59 |
|
| 60 |
def get_message_regex(self, observation: TrustAndSplitObs) -> str:
|
|
|
|
| 61 |
return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
|
| 62 |
|
| 63 |
# def get_message_regex(self, observation: TrustAndSplitObs) -> str:
|
| 64 |
# return rf"(?s).{{0,{self.num_message_chars}}}"
|
| 65 |
|
| 66 |
def get_split_regex(self, observation: TrustAndSplitObs) -> str:
|
|
|
|
| 67 |
items = list(observation.quantities.keys())
|
| 68 |
# Accept both singular and plural forms
|
| 69 |
item_pattern = "|".join(
|
|
@@ -75,6 +84,7 @@ class TrustAndSplitAgent(NegotiationAgent):
|
|
| 75 |
def get_split_action(
|
| 76 |
self, policy_output: str, observation: TrustAndSplitObs
|
| 77 |
) -> Split:
|
|
|
|
| 78 |
items = list(observation.quantities.keys())
|
| 79 |
import re as _re
|
| 80 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/negotiation/tas_agent.py
|
| 3 |
+
Summary: Agent implementation for Trust-and-Split negotiations.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
|
| 7 |
from mllm.markov_games.negotiation.nego_simulation import Split
|
| 8 |
from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
|
| 9 |
|
| 10 |
|
| 11 |
class TrustAndSplitAgent(NegotiationAgent):
|
| 12 |
+
"""Prompt/template wrapper for the classic multi-item Trust-and-Split benchmark."""
|
| 13 |
+
|
| 14 |
def __init__(self, num_message_chars, *args, **kwargs):
|
| 15 |
self.num_message_chars = num_message_chars
|
| 16 |
super().__init__(*args, **kwargs)
|
|
|
|
| 65 |
self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
|
| 66 |
|
| 67 |
def get_message_regex(self, observation: TrustAndSplitObs) -> str:
|
| 68 |
+
"""Constrain chat to bounded XML tags for stable parsing."""
|
| 69 |
return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
|
| 70 |
|
| 71 |
# def get_message_regex(self, observation: TrustAndSplitObs) -> str:
|
| 72 |
# return rf"(?s).{{0,{self.num_message_chars}}}"
|
| 73 |
|
| 74 |
def get_split_regex(self, observation: TrustAndSplitObs) -> str:
|
| 75 |
+
"""Allow natural-language item names while still returning machine-parsable XML."""
|
| 76 |
items = list(observation.quantities.keys())
|
| 77 |
# Accept both singular and plural forms
|
| 78 |
item_pattern = "|".join(
|
|
|
|
| 84 |
def get_split_action(
|
| 85 |
self, policy_output: str, observation: TrustAndSplitObs
|
| 86 |
) -> Split:
|
| 87 |
+
"""Convert human-readable allocation text back into canonical item IDs."""
|
| 88 |
items = list(observation.quantities.keys())
|
| 89 |
import re as _re
|
| 90 |
|
src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import copy
|
| 2 |
from collections.abc import Callable
|
| 3 |
from dataclasses import dataclass
|
|
@@ -15,6 +20,8 @@ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
|
|
| 15 |
|
| 16 |
|
| 17 |
class TrustAndSplitRPSAgent(NegotiationAgent):
|
|
|
|
|
|
|
| 18 |
def __init__(
|
| 19 |
self,
|
| 20 |
num_message_chars: int,
|
|
@@ -88,6 +95,7 @@ class TrustAndSplitRPSAgent(NegotiationAgent):
|
|
| 88 |
self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
|
| 89 |
|
| 90 |
def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
|
|
|
|
| 91 |
if self.message_start_end_format:
|
| 92 |
return (
|
| 93 |
rf"<<message_start>>[\s\S]{{0,{self.num_message_chars}}}<<message_end>>"
|
|
@@ -96,6 +104,7 @@ class TrustAndSplitRPSAgent(NegotiationAgent):
|
|
| 96 |
return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
|
| 97 |
|
| 98 |
def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
|
|
|
|
| 99 |
if self.proposal_start_end_format:
|
| 100 |
return r"<<proposal_start>> ?(10|[0-9]) ?<<proposal_end>>"
|
| 101 |
else:
|
|
@@ -104,6 +113,7 @@ class TrustAndSplitRPSAgent(NegotiationAgent):
|
|
| 104 |
def get_split_action(
|
| 105 |
self, policy_output: str, observation: TrustAndSplitRPSObs
|
| 106 |
) -> Split:
|
|
|
|
| 107 |
import re as _re
|
| 108 |
|
| 109 |
if self.proposal_start_end_format:
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/negotiation/tas_rps_agent.py
|
| 3 |
+
Summary: Agent logic for TAS Rock-Paper-Scissors blended game.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import copy
|
| 7 |
from collections.abc import Callable
|
| 8 |
from dataclasses import dataclass
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class TrustAndSplitRPSAgent(NegotiationAgent):
|
| 23 |
+
"""NegotiationAgent that reasons about hidden hands before submitting TAS splits."""
|
| 24 |
+
|
| 25 |
def __init__(
|
| 26 |
self,
|
| 27 |
num_message_chars: int,
|
|
|
|
| 95 |
self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
|
| 96 |
|
| 97 |
def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
|
| 98 |
+
"""Switch between <message>...</message> and <<message_start>> formats on demand."""
|
| 99 |
if self.message_start_end_format:
|
| 100 |
return (
|
| 101 |
rf"<<message_start>>[\s\S]{{0,{self.num_message_chars}}}<<message_end>>"
|
|
|
|
| 104 |
return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
|
| 105 |
|
| 106 |
def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
|
| 107 |
+
"""Force single-number proposals inside whichever tag style the config selected."""
|
| 108 |
if self.proposal_start_end_format:
|
| 109 |
return r"<<proposal_start>> ?(10|[0-9]) ?<<proposal_end>>"
|
| 110 |
else:
|
|
|
|
| 113 |
def get_split_action(
|
| 114 |
self, policy_output: str, observation: TrustAndSplitRPSObs
|
| 115 |
) -> Split:
|
| 116 |
+
"""Parse the proposal tag (or raw integer fallback) into a Split."""
|
| 117 |
import re as _re
|
| 118 |
|
| 119 |
if self.proposal_start_end_format:
|
src_code_for_reproducibility/markov_games/rollout_tree.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from __future__ import annotations
|
|
@@ -18,11 +19,15 @@ AgentId = str
|
|
| 18 |
|
| 19 |
|
| 20 |
class SimulationStepLog(BaseModel):
|
|
|
|
|
|
|
| 21 |
rewards: dict[AgentId, float]
|
| 22 |
info: Any = None
|
| 23 |
|
| 24 |
|
| 25 |
class AgentActLog(BaseModel):
|
|
|
|
|
|
|
| 26 |
chat_turns: list[ChatTurn] | None
|
| 27 |
info: Any = None
|
| 28 |
|
|
@@ -55,6 +60,8 @@ class StepLog(BaseModel):
|
|
| 55 |
|
| 56 |
|
| 57 |
class RolloutTreeNode(BaseModel):
|
|
|
|
|
|
|
| 58 |
step_log: StepLog
|
| 59 |
time_step: int
|
| 60 |
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
|
@@ -70,6 +77,8 @@ class RolloutTreeBranchNode(BaseModel):
|
|
| 70 |
|
| 71 |
|
| 72 |
class RolloutTreeRootNode(BaseModel):
|
|
|
|
|
|
|
| 73 |
id: int
|
| 74 |
crn_id: int # ID of the rng used to generate this rollout tree
|
| 75 |
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/markov_games/rollout_tree.py
|
| 3 |
+
Summary: Defines rollout tree data structures and serialization helpers.
|
| 4 |
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class SimulationStepLog(BaseModel):
|
| 22 |
+
"""Minimal snapshot of environment-side rewards and auxiliary info."""
|
| 23 |
+
|
| 24 |
rewards: dict[AgentId, float]
|
| 25 |
info: Any = None
|
| 26 |
|
| 27 |
|
| 28 |
class AgentActLog(BaseModel):
|
| 29 |
+
"""LLM-side provenance for an action (chat turns + metadata)."""
|
| 30 |
+
|
| 31 |
chat_turns: list[ChatTurn] | None
|
| 32 |
info: Any = None
|
| 33 |
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
class RolloutTreeNode(BaseModel):
|
| 63 |
+
"""Single timestep of the main trajectory (or a branch) plus linkage."""
|
| 64 |
+
|
| 65 |
step_log: StepLog
|
| 66 |
time_step: int
|
| 67 |
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
class RolloutTreeRootNode(BaseModel):
|
| 80 |
+
"""Entry point for serialized rollouts (main path plus optional branches)."""
|
| 81 |
+
|
| 82 |
id: int
|
| 83 |
crn_id: int # ID of the rng used to generate this rollout tree
|
| 84 |
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
src_code_for_reproducibility/markov_games/run_markov_games.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
from collections.abc import Callable
|
| 3 |
from dataclasses import dataclass
|
|
@@ -14,6 +19,12 @@ async def run_markov_games(
|
|
| 14 |
output_folder: str,
|
| 15 |
markov_games: list[MarkovGame],
|
| 16 |
) -> list[RolloutTreeRootNode]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
tasks = []
|
| 18 |
for mg in markov_games:
|
| 19 |
tasks.append(
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/run_markov_games.py
|
| 3 |
+
Summary: Async helper that runs configured Markov-game rollouts concurrently.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
from collections.abc import Callable
|
| 8 |
from dataclasses import dataclass
|
|
|
|
| 19 |
output_folder: str,
|
| 20 |
markov_games: list[MarkovGame],
|
| 21 |
) -> list[RolloutTreeRootNode]:
|
| 22 |
+
"""
|
| 23 |
+
Kick off multiple Markov game rollouts concurrently and return their trees.
|
| 24 |
+
|
| 25 |
+
Parameters mirror the Hydra configs (runner callable + kwargs) so callers can
|
| 26 |
+
choose ``LinearRunner``, ``AlternativeActionsRunner`` or future variants.
|
| 27 |
+
"""
|
| 28 |
tasks = []
|
| 29 |
for mg in markov_games:
|
| 30 |
tasks.append(
|
src_code_for_reproducibility/markov_games/simulation.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This is the job of the `Agent` class.
|
| 5 |
-
Simulations expect clean actions, and are defined similarly to `gymnasium` environments, except that they are adapted for the Multi-agent setting.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from abc import ABC, abstractmethod
|
|
@@ -22,59 +20,68 @@ class Simulation(ABC):
|
|
| 22 |
@abstractmethod
|
| 23 |
def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
|
| 24 |
"""
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
"""
|
| 27 |
raise NotImplementedError
|
| 28 |
|
| 29 |
def get_obs(self):
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
Returns:
|
| 33 |
-
observations
|
| 34 |
-
"""
|
| 35 |
raise NotImplementedError
|
| 36 |
|
| 37 |
def get_obs_agent(self, agent_id):
|
| 38 |
-
"""
|
| 39 |
raise NotImplementedError
|
| 40 |
|
| 41 |
def get_obs_size(self):
|
| 42 |
-
"""
|
| 43 |
raise NotImplementedError
|
| 44 |
|
| 45 |
def get_state(self):
|
|
|
|
| 46 |
raise NotImplementedError
|
| 47 |
|
| 48 |
def get_state_size(self):
|
| 49 |
-
"""
|
| 50 |
raise NotImplementedError
|
| 51 |
|
| 52 |
def get_avail_actions(self):
|
|
|
|
| 53 |
raise NotImplementedError
|
| 54 |
|
| 55 |
def get_avail_agent_actions(self, agent_id):
|
| 56 |
-
"""
|
| 57 |
raise NotImplementedError
|
| 58 |
|
| 59 |
def get_total_actions(self):
|
| 60 |
-
"""Returns the total number of actions an agent could ever take
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
raise NotImplementedError
|
| 63 |
|
| 64 |
def get_safe_copy(self):
|
| 65 |
"""
|
| 66 |
-
Return copy of the
|
| 67 |
"""
|
| 68 |
raise NotImplementedError
|
| 69 |
|
| 70 |
def reset(self):
|
| 71 |
-
"""
|
| 72 |
raise NotImplementedError
|
| 73 |
|
| 74 |
def render(self):
|
|
|
|
| 75 |
raise NotImplementedError
|
| 76 |
|
| 77 |
def close(self):
|
|
|
|
| 78 |
raise NotImplementedError
|
| 79 |
|
| 80 |
# def seed(self):
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/markov_games/simulation.py
|
| 3 |
+
Summary: Abstract ``Simulation`` base class defining the environment interface for Markov games.
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from abc import ABC, abstractmethod
|
|
|
|
| 20 |
@abstractmethod
|
| 21 |
def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
|
| 22 |
"""
|
| 23 |
+
Advance the environment by one logical tick using ``actions``.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
terminated: bool
|
| 28 |
+
Whether the episode has finished.
|
| 29 |
+
SimulationStepLog
|
| 30 |
+
Reward/info bundle describing this transition.
|
| 31 |
"""
|
| 32 |
raise NotImplementedError
|
| 33 |
|
| 34 |
def get_obs(self):
|
| 35 |
+
"""Return a dict mapping agent_id -> observation for *all* agents."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
raise NotImplementedError
|
| 37 |
|
| 38 |
def get_obs_agent(self, agent_id):
|
| 39 |
+
"""Return the observation for a single agent."""
|
| 40 |
raise NotImplementedError
|
| 41 |
|
| 42 |
def get_obs_size(self):
|
| 43 |
+
"""Describe the observation tensor shape (useful for critic heads)."""
|
| 44 |
raise NotImplementedError
|
| 45 |
|
| 46 |
def get_state(self):
|
| 47 |
+
"""Return the privileged simulator state if available."""
|
| 48 |
raise NotImplementedError
|
| 49 |
|
| 50 |
def get_state_size(self):
|
| 51 |
+
"""Describe the state tensor shape."""
|
| 52 |
raise NotImplementedError
|
| 53 |
|
| 54 |
def get_avail_actions(self):
|
| 55 |
+
"""Return the global action mask/tensor if the space is discrete."""
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
def get_avail_agent_actions(self, agent_id):
|
| 59 |
+
"""Return the available action mask for a given agent."""
|
| 60 |
raise NotImplementedError
|
| 61 |
|
| 62 |
def get_total_actions(self):
|
| 63 |
+
"""Returns the total number of actions an agent could ever take.
|
| 64 |
+
|
| 65 |
+
Implementations currently assume a discrete, one-dimensional action space per agent.
|
| 66 |
+
"""
|
| 67 |
raise NotImplementedError
|
| 68 |
|
| 69 |
def get_safe_copy(self):
|
| 70 |
"""
|
| 71 |
+
Return copy of the simulator that shares no mutable state with the original.
|
| 72 |
"""
|
| 73 |
raise NotImplementedError
|
| 74 |
|
| 75 |
def reset(self):
|
| 76 |
+
"""Reset to the initial state and return the starting observations."""
|
| 77 |
raise NotImplementedError
|
| 78 |
|
| 79 |
def render(self):
|
| 80 |
+
"""Optional human-facing visualization."""
|
| 81 |
raise NotImplementedError
|
| 82 |
|
| 83 |
def close(self):
|
| 84 |
+
"""Release any owned resources (files, processes, etc.)."""
|
| 85 |
raise NotImplementedError
|
| 86 |
|
| 87 |
# def seed(self):
|
src_code_for_reproducibility/markov_games/statistics_runner.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import gc
|
|
@@ -36,17 +41,20 @@ def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]:
|
|
| 36 |
def iterate_main_simulation_logs(
|
| 37 |
root: RolloutTreeRootNode,
|
| 38 |
) -> Iterator[SimulationStepLog]:
|
|
|
|
| 39 |
for node in _iterate_main_nodes(root):
|
| 40 |
yield node.step_log.simulation_step_log
|
| 41 |
|
| 42 |
|
| 43 |
def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
|
|
|
|
| 44 |
for p in iteration_folder.rglob("*.rt.pkl"):
|
| 45 |
if p.is_file():
|
| 46 |
yield p
|
| 47 |
|
| 48 |
|
| 49 |
def load_root(path: Path) -> RolloutTreeRootNode:
|
|
|
|
| 50 |
with open(path, "rb") as f:
|
| 51 |
data = pickle.load(f)
|
| 52 |
return RolloutTreeRootNode.model_validate(data)
|
|
@@ -54,6 +62,8 @@ def load_root(path: Path) -> RolloutTreeRootNode:
|
|
| 54 |
|
| 55 |
@dataclass
|
| 56 |
class StatRecord:
|
|
|
|
|
|
|
| 57 |
mgid: int
|
| 58 |
crn_id: Optional[int]
|
| 59 |
iteration: str
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/statistics_runner.py
|
| 3 |
+
Summary: Loads saved rollout trees (``*.rt.pkl``) and computes experiment statistics from them.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import gc
|
|
|
|
| 41 |
def iterate_main_simulation_logs(
|
| 42 |
root: RolloutTreeRootNode,
|
| 43 |
) -> Iterator[SimulationStepLog]:
|
| 44 |
+
"""Yield ``SimulationStepLog`` objects along the main (non-branch) path."""
|
| 45 |
for node in _iterate_main_nodes(root):
|
| 46 |
yield node.step_log.simulation_step_log
|
| 47 |
|
| 48 |
|
| 49 |
def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
|
| 50 |
+
"""Iterate over every ``*.rt.pkl`` file under an iteration directory."""
|
| 51 |
for p in iteration_folder.rglob("*.rt.pkl"):
|
| 52 |
if p.is_file():
|
| 53 |
yield p
|
| 54 |
|
| 55 |
|
| 56 |
def load_root(path: Path) -> RolloutTreeRootNode:
|
| 57 |
+
"""Load and validate a rollout tree from disk."""
|
| 58 |
with open(path, "rb") as f:
|
| 59 |
data = pickle.load(f)
|
| 60 |
return RolloutTreeRootNode.model_validate(data)
|
|
|
|
| 62 |
|
| 63 |
@dataclass
|
| 64 |
class StatRecord:
|
| 65 |
+
"""Convenience container for serialized stat rows."""
|
| 66 |
+
|
| 67 |
mgid: int
|
| 68 |
crn_id: Optional[int]
|
| 69 |
iteration: str
|
src_code_for_reproducibility/models/__init__.py
CHANGED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/__init__.py
|
| 3 |
+
Summary: Exports model-layer utilities from the models package.
|
| 4 |
+
"""
|
src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc and b/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/models/adapter_training_wrapper.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
| 3 |
import logging
|
| 4 |
from typing import Union
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
|
@@ -18,13 +21,14 @@ class AdapterWrapper(nn.Module):
|
|
| 18 |
• exposes only the parameters that should be trained for that adapter
|
| 19 |
(plus whatever extra modules you name).
|
| 20 |
"""
|
|
|
|
| 21 |
def __init__(
|
| 22 |
self,
|
| 23 |
shared_llm: nn.Module,
|
| 24 |
adapter_id: str,
|
| 25 |
lora_config: dict,
|
| 26 |
path: Union[str, None] = None,
|
| 27 |
-
|
| 28 |
super().__init__()
|
| 29 |
self.shared_llm = shared_llm
|
| 30 |
self.adapter_id = adapter_id
|
|
@@ -47,7 +51,9 @@ class AdapterWrapper(nn.Module):
|
|
| 47 |
adapter_name=adapter_id,
|
| 48 |
)
|
| 49 |
loaded_from = path
|
| 50 |
-
except
|
|
|
|
|
|
|
| 51 |
logger.warning(
|
| 52 |
f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
|
| 53 |
)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/adapter_training_wrapper.py
|
| 3 |
+
Summary: Wraps a shared LLM with adapter-specific PEFT handling for training.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import logging
|
| 7 |
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from peft import LoraConfig, get_peft_model
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
|
|
|
| 21 |
• exposes only the parameters that should be trained for that adapter
|
| 22 |
(plus whatever extra modules you name).
|
| 23 |
"""
|
| 24 |
+
|
| 25 |
def __init__(
|
| 26 |
self,
|
| 27 |
shared_llm: nn.Module,
|
| 28 |
adapter_id: str,
|
| 29 |
lora_config: dict,
|
| 30 |
path: Union[str, None] = None,
|
| 31 |
+
):
|
| 32 |
super().__init__()
|
| 33 |
self.shared_llm = shared_llm
|
| 34 |
self.adapter_id = adapter_id
|
|
|
|
| 51 |
adapter_name=adapter_id,
|
| 52 |
)
|
| 53 |
loaded_from = path
|
| 54 |
+
except (
|
| 55 |
+
Exception
|
| 56 |
+
) as exc: # noqa: BLE001 - want to log any load failure context
|
| 57 |
logger.warning(
|
| 58 |
f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
|
| 59 |
)
|
src_code_for_reproducibility/models/human_policy.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import os
|
| 3 |
import re
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/human_policy.py
|
| 3 |
+
Summary: Implements an interactive human-in-the-loop policy for experiments.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
import os
|
| 8 |
import re
|
src_code_for_reproducibility/models/inference_backend.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from typing import Any, Optional
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/inference_backend.py
|
| 3 |
+
Summary: Declares the inference backend interface and shared dataclasses.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
from abc import ABC, abstractmethod
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import Any, Optional
|
src_code_for_reproducibility/models/inference_backend_dummy.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
from typing import Optional
|
| 3 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/inference_backend_dummy.py
|
| 3 |
+
Summary: Stub inference backend that returns synthetic completions for tests.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
from typing import Optional
|
| 8 |
|
src_code_for_reproducibility/models/inference_backend_vllm.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import re
|
| 3 |
from typing import Optional
|
|
@@ -23,19 +28,12 @@ class VLLMAsyncBackend(LLMInferenceBackend):
|
|
| 23 |
sampling_params: dict = {},
|
| 24 |
):
|
| 25 |
self.model_name = model_name
|
| 26 |
-
# self.adapter_paths = adapter_paths or {}
|
| 27 |
-
# self.current_adapter = None
|
| 28 |
-
# self.vllm_adapter_ids = {
|
| 29 |
-
# adapter_id: generate_short_id() for adapter_id in adapter_paths.keys()
|
| 30 |
-
# }
|
| 31 |
self.vllm_adapter_ids = {}
|
| 32 |
ea = dict(model=model_name, **engine_init_kwargs)
|
| 33 |
-
# ea["enable_lora"] = True
|
| 34 |
-
# ea["max_loras"] = len(self.vllm_adapter_ids)
|
| 35 |
-
# ea["enable_sleep_mode"] = True
|
| 36 |
self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))
|
| 37 |
|
| 38 |
self.sampling_params = sampling_params
|
|
|
|
| 39 |
|
| 40 |
def prepare_adapter(
|
| 41 |
self,
|
|
@@ -43,7 +41,6 @@ class VLLMAsyncBackend(LLMInferenceBackend):
|
|
| 43 |
adapter_path: Optional[str],
|
| 44 |
weights_got_updated: bool,
|
| 45 |
) -> None:
|
| 46 |
-
# self.current_adapter = adapter_id
|
| 47 |
if weights_got_updated:
|
| 48 |
self.vllm_adapter_ids[adapter_id] = generate_short_id()
|
| 49 |
self.current_lora_request = LoRARequest(
|
|
@@ -96,9 +93,6 @@ class VLLMAsyncBackend(LLMInferenceBackend):
|
|
| 96 |
]
|
| 97 |
log_probs = torch.tensor(log_probs)
|
| 98 |
out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
|
| 99 |
-
# for out_token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs):
|
| 100 |
-
# if logprob_dict[out_token_id].logprob < -1:
|
| 101 |
-
# print(f"High negative logprob {logprob_dict[out_token_id].logprob} for {logprob_dict}")
|
| 102 |
content = raw_text
|
| 103 |
reasoning_content = None
|
| 104 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/inference_backend_vllm.py
|
| 3 |
+
Summary: Connects to in-process vLLM instances for batched generation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import asyncio
|
| 7 |
import re
|
| 8 |
from typing import Optional
|
|
|
|
| 28 |
sampling_params: dict = {},
|
| 29 |
):
|
| 30 |
self.model_name = model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
self.vllm_adapter_ids = {}
|
| 32 |
ea = dict(model=model_name, **engine_init_kwargs)
|
|
|
|
|
|
|
|
|
|
| 33 |
self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))
|
| 34 |
|
| 35 |
self.sampling_params = sampling_params
|
| 36 |
+
self.tokenizer = tokenizer
|
| 37 |
|
| 38 |
def prepare_adapter(
|
| 39 |
self,
|
|
|
|
| 41 |
adapter_path: Optional[str],
|
| 42 |
weights_got_updated: bool,
|
| 43 |
) -> None:
|
|
|
|
| 44 |
if weights_got_updated:
|
| 45 |
self.vllm_adapter_ids[adapter_id] = generate_short_id()
|
| 46 |
self.current_lora_request = LoRARequest(
|
|
|
|
| 93 |
]
|
| 94 |
log_probs = torch.tensor(log_probs)
|
| 95 |
out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
| 96 |
content = raw_text
|
| 97 |
reasoning_content = None
|
| 98 |
|
src_code_for_reproducibility/models/large_language_model_api.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import asyncio
|
|
@@ -13,7 +18,7 @@ from openai import AsyncOpenAI, OpenAIError
|
|
| 13 |
from mllm.markov_games.rollout_tree import ChatTurn
|
| 14 |
from mllm.models.inference_backend import LLMInferenceOutput
|
| 15 |
|
| 16 |
-
#
|
| 17 |
reasoning_models = [
|
| 18 |
"gpt-5-nano",
|
| 19 |
"gpt-5-mini",
|
|
@@ -119,9 +124,7 @@ class LargeLanguageModelOpenAI:
|
|
| 119 |
agent_id: str,
|
| 120 |
regex: Optional[str] = None,
|
| 121 |
) -> LLMInferenceOutput:
|
| 122 |
-
# Remove any non-role/content keys from the prompt else openai will error
|
| 123 |
-
|
| 124 |
-
# TODO:
|
| 125 |
prompt = [{"role": p.role, "content": p.content} for p in state]
|
| 126 |
|
| 127 |
# if self.sleep_between_requests:
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/large_language_model_api.py
|
| 3 |
+
Summary: Implements API-based large-language-model inference adapters.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import asyncio
|
|
|
|
| 18 |
from mllm.markov_games.rollout_tree import ChatTurn
|
| 19 |
from mllm.models.inference_backend import LLMInferenceOutput
|
| 20 |
|
| 21 |
+
# Static list copied from the public OpenAI docs until a discovery endpoint is exposed.
|
| 22 |
reasoning_models = [
|
| 23 |
"gpt-5-nano",
|
| 24 |
"gpt-5-mini",
|
|
|
|
| 124 |
agent_id: str,
|
| 125 |
regex: Optional[str] = None,
|
| 126 |
) -> LLMInferenceOutput:
|
| 127 |
+
# Remove any non-role/content keys from the prompt else openai will error.
|
|
|
|
|
|
|
| 128 |
prompt = [{"role": p.role, "content": p.content} for p in state]
|
| 129 |
|
| 130 |
# if self.sleep_between_requests:
|
src_code_for_reproducibility/models/large_language_model_local.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import logging
|
|
@@ -16,23 +17,14 @@ import httpx
|
|
| 16 |
import requests
|
| 17 |
import torch
|
| 18 |
import torch.nn as nn
|
| 19 |
-
|
| 20 |
-
# from sglang.utils import (
|
| 21 |
-
# launch_server_cmd,
|
| 22 |
-
# print_highlight,
|
| 23 |
-
# terminate_process,
|
| 24 |
-
# wait_for_server,
|
| 25 |
-
# )
|
| 26 |
from torch.optim import SGD, Adam, AdamW, RMSprop
|
| 27 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 28 |
-
from trl import AutoModelForCausalLMWithValueHead
|
| 29 |
|
| 30 |
from mllm.chat_utils.apply_template import chat_turns_to_token_ids
|
| 31 |
from mllm.markov_games.rollout_tree import ChatTurn
|
| 32 |
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 33 |
from mllm.models.inference_backend import LLMInferenceOutput
|
| 34 |
from mllm.models.inference_backend_dummy import DummyInferenceBackend
|
| 35 |
-
from mllm.models.inference_backend_sglang import SGLangOfflineBackend
|
| 36 |
from mllm.models.inference_backend_vllm import VLLMAsyncBackend
|
| 37 |
|
| 38 |
logger = logging.getLogger(__name__)
|
|
@@ -44,7 +36,7 @@ PolicyID = str
|
|
| 44 |
|
| 45 |
class LeanLocalLLM:
|
| 46 |
"""
|
| 47 |
-
|
| 48 |
"""
|
| 49 |
|
| 50 |
def __init__(
|
|
@@ -55,7 +47,7 @@ class LeanLocalLLM:
|
|
| 55 |
hf_kwargs: dict = {},
|
| 56 |
adapter_configs: dict = {},
|
| 57 |
output_directory: str = "./models/",
|
| 58 |
-
inference_backend: Literal["vllm", "
|
| 59 |
inference_backend_sampling_params: dict = {},
|
| 60 |
inference_backend_init_kwargs: dict = {},
|
| 61 |
initial_adapter_paths: dict[str, str] | None = None,
|
|
@@ -180,15 +172,7 @@ class LeanLocalLLM:
|
|
| 180 |
# Init inference inference_backend
|
| 181 |
# ---------------------------------------------------------
|
| 182 |
|
| 183 |
-
if inference_backend == "
|
| 184 |
-
self.inference_backend = SGLangOfflineBackend(
|
| 185 |
-
model_name=self.model_name,
|
| 186 |
-
save_path=self.save_path,
|
| 187 |
-
adapter_paths=self.adapter_paths,
|
| 188 |
-
tokenizer=self.tokenizer,
|
| 189 |
-
kwargs=inference_backend_init_kwargs,
|
| 190 |
-
)
|
| 191 |
-
elif inference_backend == "vllm":
|
| 192 |
self.inference_backend = VLLMAsyncBackend(
|
| 193 |
model_name=self.model_name,
|
| 194 |
# adapter_paths=self.adapter_paths,
|
|
@@ -206,7 +190,7 @@ class LeanLocalLLM:
|
|
| 206 |
|
| 207 |
def get_inference_policies(self) -> dict[PolicyID, Callable]:
|
| 208 |
"""
|
| 209 |
-
|
| 210 |
"""
|
| 211 |
policies = {}
|
| 212 |
for adapter_id in self.adapter_ids:
|
|
@@ -242,8 +226,8 @@ class LeanLocalLLM:
|
|
| 242 |
"""
|
| 243 |
Returns wrappers over the adapters which allows them be
|
| 244 |
interfaced like regular PyTorch models.
|
| 245 |
-
|
| 246 |
-
|
| 247 |
"""
|
| 248 |
trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
|
| 249 |
return trainable_objects
|
|
@@ -297,13 +281,11 @@ class LeanLocalLLM:
|
|
| 297 |
tokenizer=self.tokenizer,
|
| 298 |
enable_thinking=self.enable_thinking,
|
| 299 |
)
|
| 300 |
-
# print(f"context is {self.tokenizer.decode(context_token_ids)}")
|
| 301 |
policy_output = await self.inference_backend.generate(
|
| 302 |
input_token_ids=context_token_ids.tolist(),
|
| 303 |
extract_thinking=(self.max_thinking_characters > 0),
|
| 304 |
regex=current_regex,
|
| 305 |
)
|
| 306 |
-
# print(f"generated: {self.tokenizer.decode(policy_output.out_token_ids)}")
|
| 307 |
if (
|
| 308 |
pattern is None
|
| 309 |
or (pattern.fullmatch(policy_output.content))
|
|
@@ -347,11 +329,6 @@ class LeanLocalLLM:
|
|
| 347 |
for adapter_id in self.past_agent_adapter_ids:
|
| 348 |
self.weights_got_updated[adapter_id] = True
|
| 349 |
|
| 350 |
-
# import random
|
| 351 |
-
# self.save_path = self.save_path + str(random.randint(1,500))
|
| 352 |
-
# print(f"Save path: {self.save_path}")
|
| 353 |
-
# self.adapter_paths = {adapter_id:os.path.join(self.save_path, adapter_id) for adapter_id in self.adapter_ids}
|
| 354 |
-
|
| 355 |
adapter_id = self.adapter_ids[0]
|
| 356 |
self.hf_adapters[adapter_id].save_pretrained(self.save_path)
|
| 357 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
File: mllm/models/large_language_model_local.py
|
| 3 |
+
Summary: Provides a local large language model wrapper over inference backends.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import logging
|
|
|
|
| 17 |
import requests
|
| 18 |
import torch
|
| 19 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
from torch.optim import SGD, Adam, AdamW, RMSprop
|
| 21 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 22 |
|
| 23 |
from mllm.chat_utils.apply_template import chat_turns_to_token_ids
|
| 24 |
from mllm.markov_games.rollout_tree import ChatTurn
|
| 25 |
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 26 |
from mllm.models.inference_backend import LLMInferenceOutput
|
| 27 |
from mllm.models.inference_backend_dummy import DummyInferenceBackend
|
|
|
|
| 28 |
from mllm.models.inference_backend_vllm import VLLMAsyncBackend
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
|
|
|
| 36 |
|
| 37 |
class LeanLocalLLM:
|
| 38 |
"""
|
| 39 |
+
Wrapper that manages local HuggingFace models, adapters, and inference backends.
|
| 40 |
"""
|
| 41 |
|
| 42 |
def __init__(
|
|
|
|
| 47 |
hf_kwargs: dict = {},
|
| 48 |
adapter_configs: dict = {},
|
| 49 |
output_directory: str = "./models/",
|
| 50 |
+
inference_backend: Literal["vllm", "dummy"] = "vllm",
|
| 51 |
inference_backend_sampling_params: dict = {},
|
| 52 |
inference_backend_init_kwargs: dict = {},
|
| 53 |
initial_adapter_paths: dict[str, str] | None = None,
|
|
|
|
| 172 |
# Init inference inference_backend
|
| 173 |
# ---------------------------------------------------------
|
| 174 |
|
| 175 |
+
if inference_backend == "vllm":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
self.inference_backend = VLLMAsyncBackend(
|
| 177 |
model_name=self.model_name,
|
| 178 |
# adapter_paths=self.adapter_paths,
|
|
|
|
| 190 |
|
| 191 |
def get_inference_policies(self) -> dict[PolicyID, Callable]:
|
| 192 |
"""
|
| 193 |
+
Build async policy callables keyed by adapter id for inference-only usage.
|
| 194 |
"""
|
| 195 |
policies = {}
|
| 196 |
for adapter_id in self.adapter_ids:
|
|
|
|
| 226 |
"""
|
| 227 |
Returns wrappers over the adapters which allows them be
|
| 228 |
interfaced like regular PyTorch models.
|
| 229 |
+
AdapterWrapper lives in adapter_training_wrapper.py; the huggingface modules already wrap
|
| 230 |
+
parameters here, so we surface them directly until an extra shim is required.
|
| 231 |
"""
|
| 232 |
trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
|
| 233 |
return trainable_objects
|
|
|
|
| 281 |
tokenizer=self.tokenizer,
|
| 282 |
enable_thinking=self.enable_thinking,
|
| 283 |
)
|
|
|
|
| 284 |
policy_output = await self.inference_backend.generate(
|
| 285 |
input_token_ids=context_token_ids.tolist(),
|
| 286 |
extract_thinking=(self.max_thinking_characters > 0),
|
| 287 |
regex=current_regex,
|
| 288 |
)
|
|
|
|
| 289 |
if (
|
| 290 |
pattern is None
|
| 291 |
or (pattern.fullmatch(policy_output.content))
|
|
|
|
| 329 |
for adapter_id in self.past_agent_adapter_ids:
|
| 330 |
self.weights_got_updated[adapter_id] = True
|
| 331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
adapter_id = self.adapter_ids[0]
|
| 333 |
self.hf_adapters[adapter_id].save_pretrained(self.save_path)
|
| 334 |
|
src_code_for_reproducibility/models/scalar_critic.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from peft import LoraConfig, get_peft_model
|
|
|
|
| 4 |
|
| 5 |
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 6 |
|
|
@@ -11,18 +18,16 @@ class ScalarCritic(nn.Module):
|
|
| 11 |
V_φ(s) = wᵀ h_last + b
|
| 12 |
Only LoRA adapters (inside critic_adapter) and the value head are trainable.
|
| 13 |
"""
|
|
|
|
| 14 |
def __init__(self, critic_adapter: AdapterWrapper):
|
| 15 |
super().__init__()
|
| 16 |
self.critic_adapter = critic_adapter
|
| 17 |
hidden_size = self.critic_adapter.shared_llm.config.hidden_size
|
| 18 |
self.value_head = nn.Linear(hidden_size, 1).to(
|
| 19 |
-
dtype=critic_adapter.dtype,
|
| 20 |
-
|
| 21 |
|
| 22 |
-
def forward(self,
|
| 23 |
-
input_ids,
|
| 24 |
-
attention_mask=None,
|
| 25 |
-
**kwargs):
|
| 26 |
# AdapterWrapper activates its own adapter internally
|
| 27 |
outputs = self.critic_adapter(
|
| 28 |
input_ids=input_ids,
|
|
@@ -30,7 +35,7 @@ class ScalarCritic(nn.Module):
|
|
| 30 |
output_hidden_states=True,
|
| 31 |
**kwargs,
|
| 32 |
)
|
| 33 |
-
h_last = outputs.hidden_states[-1]
|
| 34 |
values = self.value_head(h_last).squeeze(-1) # (B, S)
|
| 35 |
return values
|
| 36 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/scalar_critic.py
|
| 3 |
+
Summary: Defines a scalar critic network and helper utilities.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
from peft import LoraConfig, get_peft_model
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
|
| 12 |
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 13 |
|
|
|
|
| 18 |
V_φ(s) = wᵀ h_last + b
|
| 19 |
Only LoRA adapters (inside critic_adapter) and the value head are trainable.
|
| 20 |
"""
|
| 21 |
+
|
| 22 |
def __init__(self, critic_adapter: AdapterWrapper):
|
| 23 |
super().__init__()
|
| 24 |
self.critic_adapter = critic_adapter
|
| 25 |
hidden_size = self.critic_adapter.shared_llm.config.hidden_size
|
| 26 |
self.value_head = nn.Linear(hidden_size, 1).to(
|
| 27 |
+
dtype=critic_adapter.dtype, device=critic_adapter.device
|
| 28 |
+
)
|
| 29 |
|
| 30 |
+
def forward(self, input_ids, attention_mask=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
| 31 |
# AdapterWrapper activates its own adapter internally
|
| 32 |
outputs = self.critic_adapter(
|
| 33 |
input_ids=input_ids,
|
|
|
|
| 35 |
output_hidden_states=True,
|
| 36 |
**kwargs,
|
| 37 |
)
|
| 38 |
+
h_last = outputs.hidden_states[-1] # (B, S, H)
|
| 39 |
values = self.value_head(h_last).squeeze(-1) # (B, S)
|
| 40 |
return values
|
| 41 |
|
src_code_for_reproducibility/training/__init__.py
CHANGED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/training/__init__.py
|
| 3 |
+
Summary: Exposes training submodules through the package namespace.
|
| 4 |
+
"""
|
src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/trainer_sum_rewards.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc
CHANGED
|
Binary files a/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc and b/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc differ
|
|
|
src_code_for_reproducibility/training/annealing_methods.py
CHANGED
|
@@ -1,6 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
|
| 3 |
|
| 4 |
def sigmoid_annealing(step: int, temperature: float) -> float:
|
| 5 |
-
|
|
|
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/training/annealing_methods.py
|
| 3 |
+
Summary: Implements annealing schedules used across training loops.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
|
| 9 |
def sigmoid_annealing(step: int, temperature: float) -> float:
|
| 10 |
+
"""
|
| 11 |
+
Smoothly ramp a scalar from 0 → 1 using a temperature-controlled sigmoid.
|
| 12 |
|
| 13 |
+
Args:
|
| 14 |
+
step: Current training step or iteration.
|
| 15 |
+
temperature: Controls how sharp the transition is; larger values flatten the curve.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
Float in (-1, 1) (in [0, 1) for non-negative step and positive temperature) that can be rescaled for annealing schedules.
|
| 19 |
+
"""
|
| 20 |
+
return 2 / (1 + np.exp(-step / temperature)) - 1
|