Add files using upload-large-folder tool
Browse files- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_model.safetensors +3 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_model.safetensors +3 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_model.safetensors +3 -0
- seed_0/agent_trainer/policy_optimizer_state.pt +3 -0
- seed_0/agent_trainer/trainer_annealing_state.pkl +3 -0
- seed_0/random_state.pkl +3 -0
- src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/dond_agent.py +75 -0
- src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py +252 -0
- src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/adapter_training_wrapper.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/large_language_model_gemini_api.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc +0 -0
- src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc +0 -0
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6add30df6b66776172322b39e7314659ebdc01e393a2c23c6be659ab5fcbeffd
|
| 3 |
+
size 323014168
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50cfa136e5499e5b1f83c90753b519572d60a378c94d09953a2738af6a8ae3c1
|
| 3 |
+
size 323014168
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7c1605ae0836578b011534ca9f02f01ab903bb99c9d3acd229f702d1613c046
|
| 3 |
+
size 323014168
|
seed_0/agent_trainer/policy_optimizer_state.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c61bf98d3328b3ed76ef4d2496e7e6ac114f54b9b7b71d75265e41ac95a8195
|
| 3 |
+
size 646269121
|
seed_0/agent_trainer/trainer_annealing_state.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:09bcf2bd05ac3d675df0a5420216edac0eb8e58b84a53ee812fa567ccb0476cb
|
| 3 |
+
size 104
|
seed_0/random_state.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0fc49c859cd303ac116afc9699963ddc86f25a3aa08f9722fdb15bdb35c642dd
|
| 3 |
+
size 12176
|
src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc
ADDED
|
Binary file (5.42 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/dond_agent.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/negotiation/dond_agent.py
|
| 3 |
+
Summary: Agent implementation for Deal-or-No-Deal style negotiations.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import re
|
| 8 |
+
from collections.abc import Callable
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Dict, List, Tuple
|
| 11 |
+
|
| 12 |
+
from mllm.markov_games.agent import Agent
|
| 13 |
+
from mllm.markov_games.negotiation.dond_simulation import DealNoDealObs
|
| 14 |
+
from mllm.markov_games.negotiation.nego_agent import (
|
| 15 |
+
NegotiationAgent,
|
| 16 |
+
NegotiationAgentState,
|
| 17 |
+
)
|
| 18 |
+
from mllm.markov_games.negotiation.nego_simulation import Split
|
| 19 |
+
from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DealNoDealAgent(NegotiationAgent):
    """NegotiationAgent tailored to the Deal-or-No-Deal stock/value revelation rules.

    The class only customises the prompt templates and the regular
    expressions / parsing used for messages and split proposals; everything
    else is inherited from ``NegotiationAgent``.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Forward all arguments to ``NegotiationAgent`` and set the prompt templates."""
        super().__init__(*args, **kwargs)
        # NOTE: the prompt strings below are runtime text sent to the policy;
        # they are kept verbatim so previously trained adapters stay comparable.
        self.intro_prompt = (
            "You are {agent_id}. You are playing an iterated game. "
            "At each round, you and other agent will try to distribute among yourselves items of types {item_types}. "
            "You only know how much you value each item type, but not the other agent's values. "
            "You can communicate with the other agent by sending up to {quota_messages_per_agent_per_round} short messages per round. "
            "Each round, after exchanging messages, you and the other agent will submit a private proposal. "
            "A deal is accepted only if both proposals match exactly and are within stock; otherwise no deal (0 points for both at that round). "
            "The values of the items of the other agent at the previous round are revealed to you after each round. "
            "Your goal is: {goal}."
        )
        self.new_round_prompt = (
            "New round {round_nb}. Items: {stock}. Your values: {values}. "
        )
        self.last_round_prompt = (
            "Last round, other agent's values: {previous_values_coagent}. "
        )
        self.send_split_prompt = "Respond with <split>...</split> where you propose how many items of each type you want to keep."

    def get_message_regex(self, observation: DealNoDealObs) -> str:
        """Return the pattern for a chat message: ``<message>`` wrapping up to 400 characters."""
        return r"<message>[\s\S]{0,400}</message>"

    def get_split_regex(self, observation: DealNoDealObs) -> str:
        """Build the pattern a split proposal must match.

        The pattern is ``<split>`` containing one ``<item>N</item>`` tag per
        entry of ``observation.item_types`` (in that order), where ``N`` is any
        integer from 0 to the stock listed in ``observation.quantities``
        (0 when the item type is absent from it).
        """
        parts = []
        for item_type in observation.item_types:
            stock = int(observation.quantities.get(item_type, 0))
            allowed = "|".join(str(k) for k in range(stock + 1))
            # Escape the name so regex metacharacters in an item type are
            # matched literally instead of being interpreted as operators.
            tag = re.escape(item_type)
            parts.append(rf"<{tag}>({allowed})</{tag}>")
        items_block = "".join(parts)
        return rf"(<split>{items_block}</split>)"

    def get_split_action(self, policy_output: str, observation: DealNoDealObs) -> Split:
        """Parse the XML-style proposal in ``policy_output`` into a ``Split``.

        For every item type in ``observation.item_types`` the first
        ``<item>N</item>`` occurrence is used; an item type whose tag is
        missing is treated as 0 units kept.
        """
        allocations: Dict[str, int] = {}
        for item_type in observation.item_types:
            # Same literal escaping as in get_split_regex so that generation
            # and parsing agree on the tag spelling.
            tag = re.escape(item_type)
            match = re.search(rf"<{tag}>([0-9]+)</{tag}>", policy_output)
            allocations[item_type] = int(match.group(1)) if match else 0
        return Split(items_given_to_self=allocations)
|
src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/negotiation/nego_simulation.py
|
| 3 |
+
Summary: Simulation harness for general negotiation environments.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
from abc import abstractmethod
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
from numpy.random import default_rng
|
| 12 |
+
|
| 13 |
+
from mllm.markov_games.rollout_tree import SimulationStepLog
|
| 14 |
+
from mllm.markov_games.simulation import Simulation
|
| 15 |
+
from mllm.utils.get_coagent_id import get_coagent_id
|
| 16 |
+
|
| 17 |
+
AgentId = str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class Split:
    """Split proposal submitted by a single agent.

    ``items_given_to_self`` maps an item type name to the number of units of
    that item the proposing agent wants to keep.
    """

    # item type name -> number of units the proposer keeps
    items_given_to_self: Dict[str, int]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class Message:
    """One chat message sent by an agent during the message phase."""

    # raw text of the message
    message: str
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass # gets extended by variants
|
| 35 |
+
class NegotiationState:
|
| 36 |
+
"""Full simulator state snapshot shared by all negotiation variants."""
|
| 37 |
+
|
| 38 |
+
round_nb: int
|
| 39 |
+
last_message: str
|
| 40 |
+
current_agent: AgentId
|
| 41 |
+
quantities: Dict[str, int]
|
| 42 |
+
values: Dict[AgentId, Dict[str, float]]
|
| 43 |
+
splits: Dict[AgentId, Split | None]
|
| 44 |
+
nb_messages_sent: Dict[AgentId, int]
|
| 45 |
+
previous_values: Dict[AgentId, Dict[str, float]] | None
|
| 46 |
+
previous_splits: Dict[AgentId, Dict[str, int] | None] | None
|
| 47 |
+
previous_points: Dict[AgentId, float] | None
|
| 48 |
+
previous_quantities: Dict[str, int] | None
|
| 49 |
+
split_phase: bool
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass # gets extended by variants
|
| 53 |
+
class NegotiationObs:
|
| 54 |
+
"""Observation presented to agents each turn (base fields; variants extend)."""
|
| 55 |
+
|
| 56 |
+
round_nb: int
|
| 57 |
+
last_message: str
|
| 58 |
+
quota_messages_per_agent_per_round: int
|
| 59 |
+
current_agent: AgentId
|
| 60 |
+
other_agent: str
|
| 61 |
+
quantities: Dict[str, int]
|
| 62 |
+
item_types: List[str]
|
| 63 |
+
value: Dict[str, int]
|
| 64 |
+
split_phase: bool
|
| 65 |
+
last_split_agent: Dict[str, int] | None
|
| 66 |
+
last_value_agent: Dict[str, int] | None
|
| 67 |
+
last_points_agent: float | None
|
| 68 |
+
last_split_coagent: Dict[str, int] | None
|
| 69 |
+
last_value_coagent: Dict[str, int] | None
|
| 70 |
+
last_points_coagent: float | None
|
| 71 |
+
last_quantities: Dict[str, int] | None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def compute_tas_style_rewards(
    agent_ids: List[AgentId],
    values: Dict[AgentId, float],
    splits: Dict[AgentId, Split],
    quantities: Dict[str, int],
) -> Dict[AgentId, float]:
    """Compute TAS-like rewards for the first two agents in ``agent_ids``.

    For every item type in ``quantities``: if the two agents together claim
    more units than are available, each agent receives a share of the stock
    proportional to its claim; otherwise each agent receives exactly what it
    claimed. An agent's reward is the sum over items of
    ``units_received * value_of_item``.

    Args:
        agent_ids: Agent identifiers; only the first two entries are used.
        values: Per-agent value. Either a scalar applied to every item, or a
            dict mapping item type to value.
        splits: Per-agent proposal. A ``None`` entry counts as claiming 0 of
            every item; an item missing from a proposal also counts as 0.
        quantities: Available units per item type.

    Returns:
        Mapping from each of the two agents to its total reward.
    """
    a0, a1 = agent_ids[0], agent_ids[1]
    rewards = {a0: 0.0, a1: 0.0}

    for item, stock in quantities.items():
        claims = {
            agent: (
                int(splits[agent].items_given_to_self.get(item, 0))
                if splits[agent] is not None
                else 0
            )
            for agent in (a0, a1)
        }
        denom = max(int(stock), claims[a0] + claims[a1])
        if denom == 0:
            # No stock and nothing claimed: nobody receives anything for this
            # item (previously this divided by zero).
            continue
        for agent in (a0, a1):
            kept = float(stock) * float(claims[agent]) / float(denom)
            agent_values = values[agent]
            # A dict holds one value per item; anything else is a scalar
            # value shared by all items.
            if isinstance(agent_values, dict):
                unit_value = agent_values[item]
            else:
                unit_value = agent_values
            rewards[agent] += kept * float(unit_value)
    return rewards
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class NegotiationSimulation(Simulation):
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
agent_ids: List[AgentId],
|
| 116 |
+
agent_names: List[str],
|
| 117 |
+
seed: int,
|
| 118 |
+
nb_of_rounds: int,
|
| 119 |
+
quota_messages_per_agent_per_round: int,
|
| 120 |
+
item_types: List[str] | None = None,
|
| 121 |
+
):
|
| 122 |
+
self.seed = seed
|
| 123 |
+
self.rng = default_rng(self.seed)
|
| 124 |
+
self.agent_ids = list(agent_ids)
|
| 125 |
+
self.agent_names = agent_names
|
| 126 |
+
self.agent_id_to_name = {
|
| 127 |
+
agent_id: agent_name for agent_id, agent_name in zip(agent_ids, agent_names)
|
| 128 |
+
}
|
| 129 |
+
self.nb_of_rounds = int(nb_of_rounds)
|
| 130 |
+
self.quota_messages_per_agent_per_round = int(
|
| 131 |
+
quota_messages_per_agent_per_round
|
| 132 |
+
)
|
| 133 |
+
if item_types is not None:
|
| 134 |
+
self.item_types = [item.lower() for item in item_types]
|
| 135 |
+
else:
|
| 136 |
+
self.item_types = ["coins"]
|
| 137 |
+
self.state: NegotiationState | None = None
|
| 138 |
+
self._starting_agent_index = self.rng.choice([0, 1])
|
| 139 |
+
self.reset()
|
| 140 |
+
|
| 141 |
+
def _other(self, agent_id: AgentId) -> AgentId:
|
| 142 |
+
return get_coagent_id(self.agent_ids, agent_id)
|
| 143 |
+
|
| 144 |
+
@abstractmethod
|
| 145 |
+
def set_new_round_of_variant(self):
|
| 146 |
+
"""Variant hook: sample new private values / stock before each round."""
|
| 147 |
+
pass
|
| 148 |
+
|
| 149 |
+
@abstractmethod
|
| 150 |
+
def get_info_of_variant(
|
| 151 |
+
self, state: NegotiationState, actions: Dict[AgentId, Any]
|
| 152 |
+
) -> Dict[str, Any]:
|
| 153 |
+
"""Variant hook: populate SimulationStepLog.info with custom diagnostics."""
|
| 154 |
+
pass
|
| 155 |
+
|
| 156 |
+
def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
|
| 157 |
+
"""
|
| 158 |
+
Returns terminated, step_log
|
| 159 |
+
"""
|
| 160 |
+
assert self.state is not None
|
| 161 |
+
current_agent = self.state.current_agent
|
| 162 |
+
a0, a1 = self.agent_ids[0], self.agent_ids[1]
|
| 163 |
+
action = actions.get(current_agent)
|
| 164 |
+
|
| 165 |
+
# Split phase: require both splits in the same timestep
|
| 166 |
+
if self.state.split_phase:
|
| 167 |
+
action_a0 = actions.get(a0)
|
| 168 |
+
action_a1 = actions.get(a1)
|
| 169 |
+
have_both_splits = isinstance(action_a0, Split) and isinstance(
|
| 170 |
+
action_a1, Split
|
| 171 |
+
)
|
| 172 |
+
if not have_both_splits:
|
| 173 |
+
rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
|
| 174 |
+
return False, SimulationStepLog(
|
| 175 |
+
rewards=rewards, info={"type": "waiting_for_splits"}
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Record splits
|
| 179 |
+
self.state.splits[a0] = action_a0
|
| 180 |
+
self.state.splits[a1] = action_a1
|
| 181 |
+
|
| 182 |
+
# Compute rewards and end round
|
| 183 |
+
rewards = self.get_rewards(self.state.splits)
|
| 184 |
+
|
| 185 |
+
# Info
|
| 186 |
+
info = self.get_info_of_variant(self.state, actions)
|
| 187 |
+
|
| 188 |
+
# Prepare next round
|
| 189 |
+
# Alternate starting agent
|
| 190 |
+
self.state.round_nb += 1
|
| 191 |
+
self._starting_agent_index = 1 - self._starting_agent_index
|
| 192 |
+
self.state.current_agent = self.agent_ids[self._starting_agent_index]
|
| 193 |
+
self.state.previous_values = copy.deepcopy(self.state.values)
|
| 194 |
+
self.state.previous_splits = copy.deepcopy(self.state.splits)
|
| 195 |
+
self.state.previous_quantities = copy.deepcopy(self.state.quantities)
|
| 196 |
+
self.state.previous_points = copy.deepcopy(rewards)
|
| 197 |
+
self.state.last_message = ""
|
| 198 |
+
self.set_new_round_of_variant() # variant specific
|
| 199 |
+
self.state.splits = {agent_id: None for agent_id in self.agent_ids}
|
| 200 |
+
self.state.nb_messages_sent = {agent_id: 0 for agent_id in self.agent_ids}
|
| 201 |
+
is_last_timestep_in_round = True
|
| 202 |
+
done = self.state.round_nb >= self.nb_of_rounds
|
| 203 |
+
|
| 204 |
+
# Message phase: roll the conversation forward a single turn.
|
| 205 |
+
elif isinstance(action, Message):
|
| 206 |
+
self.state.last_message = action.message
|
| 207 |
+
self.state.nb_messages_sent[current_agent] += 1
|
| 208 |
+
|
| 209 |
+
# Move turn to other agent
|
| 210 |
+
self.state.current_agent = self._other(current_agent)
|
| 211 |
+
|
| 212 |
+
# If both agents have reached their message quota, enter split phase
|
| 213 |
+
if all(
|
| 214 |
+
self.state.nb_messages_sent[agent_id]
|
| 215 |
+
>= self.quota_messages_per_agent_per_round
|
| 216 |
+
for agent_id in self.agent_ids
|
| 217 |
+
):
|
| 218 |
+
self.state.split_phase = True
|
| 219 |
+
is_last_timestep_in_round = False
|
| 220 |
+
done = False
|
| 221 |
+
rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
|
| 222 |
+
info = {"type": "message"}
|
| 223 |
+
|
| 224 |
+
info[
|
| 225 |
+
"is_last_timestep_in_round"
|
| 226 |
+
] = is_last_timestep_in_round # Used later to group round timesteps if needed
|
| 227 |
+
return done, SimulationStepLog(rewards=rewards, info=info)
|
| 228 |
+
|
| 229 |
+
def get_obs(self):
|
| 230 |
+
"""Returns all agent observations in dict"""
|
| 231 |
+
return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}
|
| 232 |
+
|
| 233 |
+
@abstractmethod
|
| 234 |
+
def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
|
| 235 |
+
pass
|
| 236 |
+
|
| 237 |
+
@abstractmethod
|
| 238 |
+
def get_obs_agent(self, agent_id):
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
+
def get_state(self):
|
| 242 |
+
return self.state
|
| 243 |
+
|
| 244 |
+
def get_safe_copy(self):
|
| 245 |
+
"""Return a safe copy of the simulation."""
|
| 246 |
+
simulation_copy = copy.copy(self)
|
| 247 |
+
simulation_copy.state = copy.deepcopy(self.state)
|
| 248 |
+
return simulation_copy
|
| 249 |
+
|
| 250 |
+
@abstractmethod
|
| 251 |
+
def reset(self) -> dict[AgentId, NegotiationObs]:
|
| 252 |
+
pass
|
src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (260 Bytes). View file
|
|
|
src_code_for_reproducibility/models/__pycache__/adapter_training_wrapper.cpython-312.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
src_code_for_reproducibility/models/__pycache__/large_language_model_gemini_api.cpython-312.pyc
ADDED
|
Binary file (8.78 kB). View file
|
|
|
src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc
ADDED
|
Binary file (3.31 kB). View file
|
|
|
src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc
ADDED
|
Binary file (5.97 kB). View file
|
|
|