dereckpichemila commited on
Commit
fa30e5a
·
verified ·
1 Parent(s): 9e1da36

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .hydra/hydra.yaml +155 -0
  2. src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-310.pyc +0 -0
  3. src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc +0 -0
  4. src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-311.pyc +0 -0
  5. src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-311.pyc +0 -0
  6. src_code_for_reproducibility/markov_games/__pycache__/analysis_utils.cpython-310.pyc +0 -0
  7. src_code_for_reproducibility/markov_games/__pycache__/env_imports.cpython-310.pyc +0 -0
  8. src_code_for_reproducibility/markov_games/__pycache__/environment_imports.cpython-310.pyc +0 -0
  9. src_code_for_reproducibility/markov_games/__pycache__/export.cpython-310.pyc +0 -0
  10. src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-310.pyc +0 -0
  11. src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-311.pyc +0 -0
  12. src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-311.pyc +0 -0
  13. src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-311.pyc +0 -0
  14. src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-310.pyc +0 -0
  15. src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-311.pyc +0 -0
  16. src_code_for_reproducibility/markov_games/__pycache__/mg_schemas.cpython-310.pyc +0 -0
  17. src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-310.pyc +0 -0
  18. src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-311.pyc +0 -0
  19. src_code_for_reproducibility/markov_games/__pycache__/render_utils.cpython-311.pyc +0 -0
  20. src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-310.pyc +0 -0
  21. src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-311.pyc +0 -0
  22. src_code_for_reproducibility/markov_games/__pycache__/rollout_tree_extract_utils.cpython-310.pyc +0 -0
  23. src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-310.pyc +0 -0
  24. src_code_for_reproducibility/markov_games/__pycache__/runners.cpython-310.pyc +0 -0
  25. src_code_for_reproducibility/markov_games/__pycache__/scores.cpython-310.pyc +0 -0
  26. src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-310.pyc +0 -0
  27. src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-311.pyc +0 -0
  28. src_code_for_reproducibility/markov_games/__pycache__/two_chats_to_html.cpython-310.pyc +0 -0
  29. src_code_for_reproducibility/markov_games/__pycache__/types.cpython-310.pyc +0 -0
  30. src_code_for_reproducibility/markov_games/deal_no_deal/__pycache__/__init__.cpython-311.pyc +0 -0
  31. src_code_for_reproducibility/markov_games/deal_no_deal/__pycache__/dond_agent.cpython-311.pyc +0 -0
  32. src_code_for_reproducibility/markov_games/deal_no_deal/__pycache__/dond_simulation.cpython-311.pyc +0 -0
  33. src_code_for_reproducibility/markov_games/ipd/__init__.py +0 -0
  34. src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-311.pyc +0 -0
  35. src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py +10 -0
  36. src_code_for_reproducibility/markov_games/negotiation/README.md +40 -0
  37. src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-311.pyc +0 -0
  38. src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-311.pyc +0 -0
  39. src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-311.pyc +0 -0
  40. src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-311.pyc +0 -0
  41. src_code_for_reproducibility/markov_games/negotiation/dond_agent.py +61 -0
  42. src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py +153 -0
  43. src_code_for_reproducibility/markov_games/negotiation/nego_agent.py +174 -0
  44. src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py +229 -0
  45. src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py +44 -0
  46. src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py +48 -0
  47. src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py +141 -0
  48. src_code_for_reproducibility/markov_games/negotiation/tas_agent.py +61 -0
  49. src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py +85 -0
  50. src_code_for_reproducibility/markov_games/negotiation/tas_rps_simulation.py +208 -0
.hydra/hydra.yaml ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - experiment.name=tas_rps_no_regex_prev_ad_align_buffer_gae
116
+ job:
117
+ name: run
118
+ chdir: false
119
+ override_dirname: experiment.name=tas_rps_no_regex_prev_ad_align_buffer_gae
120
+ id: ???
121
+ num: ???
122
+ config_name: tas_rps_no_regex_prev_ad_align_buffer_gae
123
+ env_set: {}
124
+ env_copy: []
125
+ config:
126
+ override_dirname:
127
+ kv_sep: '='
128
+ item_sep: ','
129
+ exclude_keys: []
130
+ runtime:
131
+ version: 1.3.2
132
+ version_base: '1.1'
133
+ cwd: /home/mila/d/dereck.piche/llm_negotiation
134
+ config_sources:
135
+ - path: hydra.conf
136
+ schema: pkg
137
+ provider: hydra
138
+ - path: /home/mila/d/dereck.piche/llm_negotiation/configs
139
+ schema: file
140
+ provider: main
141
+ - path: ''
142
+ schema: structured
143
+ provider: schema
144
+ output_dir: /network/scratch/d/dereck.piche/llm_negotiation/2025_09/tas_rps_no_regex_prev_ad_align_buffer_gae
145
+ choices:
146
+ hydra/env: default
147
+ hydra/callbacks: null
148
+ hydra/job_logging: default
149
+ hydra/hydra_logging: default
150
+ hydra/hydra_help: default
151
+ hydra/help: default
152
+ hydra/sweeper: basic
153
+ hydra/launcher: basic
154
+ hydra/output: default
155
+ verbose: false
src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (180 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-311.pyc ADDED
Binary file (3.47 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-311.pyc ADDED
Binary file (5.62 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/analysis_utils.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/env_imports.cpython-310.pyc ADDED
Binary file (644 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/environment_imports.cpython-310.pyc ADDED
Binary file (822 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/export.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-310.pyc ADDED
Binary file (30.6 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-311.pyc ADDED
Binary file (41 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-311.pyc ADDED
Binary file (6.71 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-311.pyc ADDED
Binary file (1.4 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-310.pyc ADDED
Binary file (5.91 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/mg_schemas.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-310.pyc ADDED
Binary file (1.63 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-311.pyc ADDED
Binary file (3.94 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/render_utils.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-310.pyc ADDED
Binary file (3.26 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-311.pyc ADDED
Binary file (4.69 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/rollout_tree_extract_utils.cpython-310.pyc ADDED
Binary file (667 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-310.pyc ADDED
Binary file (829 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/runners.cpython-310.pyc ADDED
Binary file (2.74 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/scores.cpython-310.pyc ADDED
Binary file (8.02 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-311.pyc ADDED
Binary file (4.28 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/two_chats_to_html.cpython-310.pyc ADDED
Binary file (9.29 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/types.cpython-310.pyc ADDED
Binary file (404 Bytes). View file
 
src_code_for_reproducibility/markov_games/deal_no_deal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
src_code_for_reproducibility/markov_games/deal_no_deal/__pycache__/dond_agent.cpython-311.pyc ADDED
Binary file (9.93 kB). View file
 
src_code_for_reproducibility/markov_games/deal_no_deal/__pycache__/dond_simulation.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
src_code_for_reproducibility/markov_games/ipd/__init__.py ADDED
File without changes
src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes). View file
 
src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict
4
+
5
+ from mllm.markov_games.rollout_tree import SimulationStepLog
6
+
7
+
8
def avg_reward(sl: SimulationStepLog) -> Dict[str, float]:
    """Return one scalar reward per agent for a single simulation step.

    When the step carries no rewards (``sl.rewards`` is None), an empty
    mapping is returned.
    """
    per_agent: Dict[str, float] = {}
    for agent_id, value in (sl.rewards or {}).items():
        per_agent[agent_id] = float(value)
    return per_agent
src_code_for_reproducibility/markov_games/negotiation/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Negotiation Games: core mechanics and variants
2
+
3
+ This family of games features two agents who, in each round, may briefly communicate and then simultaneously propose how to split a fixed resource (most commonly 10 coins). Rewards are the amount kept multiplied by an agent’s per-unit value. The starting speaker alternates deterministically across rounds.
4
+
5
+ Communication is optional and variant-dependent: some settings encourage rich messaging to share private information, while others remove messaging entirely to focus on allocation behavior.
6
+
7
+ Proportional splitting is used when the two proposals exceed the available total: allocations are scaled proportionally rather than discarded. This preserves a useful learning signal even when agents over-claim.
8
+
9
+ ### Variants (in increasing difficulty)
10
+
11
+ - No‑Press Split
12
+ - Single item type (coins)
13
+ - No communication; agents go straight to making split proposals, with the starting player alternating deterministically.
14
+ - Motivation: mirrors no‑communication setups (e.g., Advantage Alignment) while keeping the split decision nontrivial.
15
+ - Deterministic Mode: values are fixed and public: one agent values coins at 10, the other at 1 (alternates each round).
16
+ - Stochastic Mode: values are random and uncorrelated.
17
+
18
+ - Trust-and-Split RPS (TAS-RPS)
19
+ - Single item type (coins)
20
+ - Each round, a rock–paper–scissors hand draw creates a strong asymmetry: the winner’s per-coin value is 10, the loser’s is 1.
21
+ - Each agent initially sees only their own hand and must communicate to coordinate an optimal split.
22
+ - Motivation: enforce large value disparity so one’s own value reveals little about the other’s (avoiding ceiling effects) and incentivize meaningful communication.
23
+
24
+ - Trust-and-Split (TAS)
25
+ - Single item type (coins); each round, each agent’s per-coin value is independently sampled in a broad range (e.g., 1–20).
26
+ - Each agent observes only their own value; they may use short messages to share and negotiate.
27
+ - Motivation: a simple blend that tests whether agents learn to exchange private information and coordinate proportional, value-aware splits.
28
+
29
+ - Deal-or-No-Deal (DOND)
30
+ - Introduced in [Deal or No Deal? End-to-End Learning for Negotiation Dialogues](https://arxiv.org/pdf/1706.05125)
31
+ - Multiple item types (typically "books", "hats" and "balls") with limited stocks; each agent has its own per-type values.
32
+ - A deal pays out only if both proposals exactly agree and respect the stock; otherwise no deal (zero reward) that round.
33
+ - Motivation: a known benchmark closer to real-world bargaining, where both parties must explicitly agree.
34
+
35
+
36
+
37
+
38
+
39
+
40
+
src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-311.pyc ADDED
Binary file (7.72 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-311.pyc ADDED
Binary file (8.1 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-311.pyc ADDED
Binary file (7.23 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/dond_agent.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import re
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ from mllm.markov_games.agent import Agent
8
+ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
9
+ from mllm.markov_games.negotiation.dond_simulation import (
10
+ DealNoDealObs,
11
+ )
12
+ from mllm.markov_games.negotiation.nego_simulation import Split
13
+ from mllm.markov_games.negotiation.nego_agent import NegotiationAgent, NegotiationAgentState
14
+
15
class DealNoDealAgent(NegotiationAgent):
    """Agent for the Deal-or-No-Deal (DOND) negotiation variant.

    Provides the DOND-specific prompt templates plus the regexes and parsing
    used to constrain and decode the policy's message/split outputs.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Prompt templates; placeholders are filled from the goal and the
        # observation's fields when the prompt is built.
        self.intro_prompt = (
            "You are {agent_id}. You are playing an iterated game. "
            "At each round, you and other agent will try to distribute among yourselves items of types {item_types}. "
            "You only know how much you value each item type, but not the other agent's values. "
            "You can communicate with the other agent by sending up to {quota_messages_per_agent_per_round} short messages per round. "
            "Each round, after exchanging messages, you and the other agent will submit a private proposal. "
            "A deal is accepted only if both proposals match exactly and are within stock; otherwise no deal (0 points for both at that round). "
            "The values of the items of the other agent at the previous round are revealed to you after each round. "
            "Your goal is: {goal}."
        )
        self.new_round_prompt = ("New round {round_nb}. Items: {stock}. Your values: {values}. ")
        self.last_round_prompt = ("Last round, other agent's values: {previous_values_coagent}. ")
        self.send_split_prompt = ("Respond with <split>...</split> where you propose how many items of each type you want to keep.")

    def get_message_regex(self, observation: DealNoDealObs) -> str:
        """Regex constraining a chat message to a bounded <message> block."""
        return r"<message>[\s\S]{0,400}</message>"

    def get_split_regex(self, observation: DealNoDealObs) -> str:
        """Regex forcing a well-formed <split> block for the current stock.

        For every item type, the allowed counts 0..stock are enumerated as an
        explicit alternation so multi-digit counts are matched exactly.
        """
        parts = []
        for t in observation.item_types:
            s = int(observation.quantities.get(t, 0))
            allowed = "|".join(str(k) for k in range(0, s + 1))
            rng = f"({allowed})"
            parts.append(fr"<{t}>{rng}</{t}>")
        items_block = "".join(parts)
        return fr"(<split>{items_block}</split>)"

    def get_split_action(self, policy_output: str, observation: DealNoDealObs) -> Split:
        """Parse the policy's <split> output into a Split action.

        Item types whose tag is missing from the output default to 0.
        """
        # Uses the module-level `re`; the original re-imported it locally as
        # `_re`, shadowing the top-of-file import for no benefit.
        allocations: Dict[str, int] = {}
        for t in observation.item_types:
            m = re.search(fr"<{t}>([0-9]+)</{t}>", policy_output)
            allocations[t] = int(m.group(1)) if m else 0
        return Split(items_given_to_self=allocations)
src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Tuple
4
+
5
+ from numpy.random import default_rng
6
+
7
+ from mllm.markov_games.rollout_tree import SimulationStepLog
8
+ from mllm.markov_games.negotiation.nego_simulation import Split, NegotiationState, NegotiationObs, NegotiationSimulation
9
+ from mllm.utils.get_coagent_id import get_coagent_id
10
+
11
+
12
+ AgentId = str
13
+
14
+
15
@dataclass
class DealNoDealState(NegotiationState):
    """Episode state for the Deal-or-No-Deal variant.

    Extends the generic NegotiationState with the item catalogue and each
    agent's private per-type valuations.
    """

    # Item-type names in play this episode (e.g. "books", "hats", "balls").
    item_types: List[str]
    # Private valuations: agent id -> {item type -> integer value}.
    values: Dict[AgentId, Dict[str, int]]
19
+
20
@dataclass
class DealNoDealObs(NegotiationObs):
    """Per-agent (private) observation for the DOND variant."""

    # This agent's own valuations: {item type -> integer value}.
    my_values: Dict[str, int]
    # Item types in play this episode.
    item_types: List[str]
    # Co-agent's valuations revealed after each round, or None before the
    # first round completes.
    previous_values_coagent: Dict[str, int] | None
25
+
26
+
27
def random_partition_integer(rng, total: int, parts: int) -> List[int]:
    """Randomly split *total* into *parts* non-negative integers summing to *total*.

    Draws ``parts - 1`` cut points uniformly in ``[0, total]`` and returns the
    gaps between consecutive cut points. Degenerate inputs: ``parts <= 0``
    yields ``[]``; ``total <= 0`` yields all zeros.
    """
    if parts <= 0:
        return []
    if total <= 0:
        return [0] * parts
    cut_points = sorted(rng.integers(0, total + 1, size=parts - 1).tolist())
    edges = [0] + cut_points + [total]
    # Each part is the gap between two consecutive edges.
    return [hi - lo for lo, hi in zip(edges[:-1], edges[1:])]
39
+
40
class DealNoDealSimulation(NegotiationSimulation):
    """Deal-or-No-Deal (DOND) negotiation simulation.

    Each round the two agents negotiate over a stock of items of several
    types. A round pays out only when the two private proposals are exactly
    complementary (for every item type they sum to the stock); otherwise both
    agents receive 0.0 for that round.
    """

    def __init__(
        self,
        item_types: List[str] | None = None,
        *args,
        **kwargs,
    ):
        # The original signature used a mutable default list
        # (["books", "hats", "balls"]); a None sentinel keeps the same
        # default while avoiding a list object shared across instances.
        if item_types is None:
            item_types = ["books", "hats", "balls"]
        super().__init__(item_types=item_types, *args, **kwargs)
        self.reset()

    def _other(self, agent_id: AgentId) -> AgentId:
        """Return the id of the other agent."""
        return get_coagent_id(self.agent_ids, agent_id)

    def _sample_stock(self) -> Dict[str, int]:
        """Sample a stock of 5-7 items, randomly partitioned across types."""
        # total items between 5 and 7 (numpy's upper bound is exclusive)
        total_items = int(self.rng.integers(5, 8))
        # nonnegative per-type counts summing to total_items
        parts = random_partition_integer(self.rng, total_items, len(self.item_types))
        # allow zeros per type
        return {t: int(c) for t, c in zip(self.item_types, parts)}

    def _sample_values_pair(self) -> Dict[AgentId, Dict[str, int]]:
        """Sample both agents' valuations (each summing to 10) by rejection.

        Accepted only when every item type is valued by at least one agent
        and some item type is valued by both agents.
        """
        while True:
            vals_a = random_partition_integer(self.rng, 10, len(self.item_types))
            vals_b = random_partition_integer(self.rng, 10, len(self.item_types))
            a = {t: int(v) for t, v in zip(self.item_types, vals_a)}
            b = {t: int(v) for t, v in zip(self.item_types, vals_b)}
            # each item valued by at least one agent
            ok1 = all((a[t] > 0) or (b[t] > 0) for t in self.item_types)
            # some item valued by both agents
            ok2 = any((a[t] > 0) and (b[t] > 0) for t in self.item_types)
            if ok1 and ok2:
                return {self.agent_ids[0]: a, self.agent_ids[1]: b}

    def _is_valid_allocation(self, allocation: Dict[str, int], stock: Dict[str, int]) -> bool:
        """True iff *allocation* gives a legal integer count for every item type."""
        for t in self.item_types:
            v = allocation.get(t)
            if v is None:
                return False
            if not isinstance(v, int):
                return False
            if v < 0 or v > int(stock.get(t, 0)):
                return False
        return True

    def set_new_round_of_variant(self):
        """Per-round variant hook: keep the same values, resample the stock."""
        self.state.quantities = self._sample_stock()

    def get_info_of_variant(self, state: NegotiationState, actions: Dict[AgentId, Any]) -> Dict[str, Any]:
        """Return a deep-copied snapshot of the round's state for logging."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "values": copy.deepcopy(state.values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Return the rewards for each agent.

        A deal pays out only when, for every item type, the two proposals sum
        exactly to the stock; otherwise both agents get 0.0 (no deal).
        """
        split_a = splits[self.agent_ids[0]].items_given_to_self
        split_b = splits[self.agent_ids[1]].items_given_to_self
        # Float accumulators so the return type matches the annotation
        # (the original initialized with int 0).
        rewards = {self.agent_ids[0]: 0.0, self.agent_ids[1]: 0.0}
        for t in self.item_types:
            # If not complementary, no deal: zero reward for both.
            if not split_a[t] + split_b[t] == self.state.quantities[t]:
                return {self.agent_ids[0]: 0.0, self.agent_ids[1]: 0.0}
            rewards[self.agent_ids[0]] += split_a[t] * self.state.values[self.agent_ids[0]][t]
            rewards[self.agent_ids[1]] += split_b[t] * self.state.values[self.agent_ids[1]][t]
        return rewards

    def get_obs(self):
        """Return every agent's observation, keyed by agent id."""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    def get_obs_agent(self, agent_id):
        """Build the private observation for *agent_id*.

        NOTE(review): the field set passed here must match the
        DealNoDealObs/NegotiationObs definitions (e.g. defaults for fields
        such as `other_agent`) — confirm against nego_simulation.py.
        """
        other_id = self._other(agent_id)
        obs = DealNoDealObs(
            round_nb=self.state.round_nb,
            last_message=self.state.last_message,
            current_agent=self.state.current_agent,
            quantities=copy.deepcopy(self.state.quantities),
            value=0.0,  # unused in DOND
            other_agent_split=None,  # not meaningful until split
            split_phase=self.state.split_phase,
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            my_values=copy.deepcopy(self.state.values[agent_id]),
            item_types=list(self.item_types),
            previous_values_coagent=copy.deepcopy(self.state.values.get(other_id, {})),
        )
        return obs

    def reset(self):
        """Start a fresh episode: new stock, new values, round counter at 0."""
        start_agent = self.agent_ids[self._starting_agent_index]
        stock = self._sample_stock()
        values = self._sample_values_pair()
        self.state = DealNoDealState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities=stock,
            values=values,
            previous_values=None,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            split_phase=False,
            item_types=list(self.item_types),
        )
        return self.get_obs()
src_code_for_reproducibility/markov_games/negotiation/nego_agent.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from abc import abstractmethod
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ from mllm.markov_games.agent import Agent
8
+ from mllm.markov_games.negotiation.nego_simulation import Message, NegotiationObs, Split
9
+ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
10
+
11
+
12
@dataclass
class NegotiationAgentState:
    """Mutable per-episode bookkeeping for a NegotiationAgent."""

    # Last round number this agent has processed (used to detect new rounds).
    round_nb: int
    # Messages sent by this agent in the current round.
    nb_messages_sent_this_round: int
    # Length of chat_history at the end of the previous act() call; act()
    # slices from here to log only the turns added in the current call.
    chat_counter: int
    # Full running conversation (user prompts and assistant outputs).
    chat_history: List[ChatTurn]
18
+
19
+
20
class NegotiationAgent(Agent):
    """Base agent for the negotiation games.

    Builds the user prompt incrementally from the observation, queries the
    policy under a format-constraining regex, and keeps the full chat history
    in ``self.state``. Variants subclass this and set the prompt template
    strings plus the three abstract regex/parsing hooks.
    """

    def __init__(
        self,
        seed: int,
        agent_id: str,
        agent_name: str,
        policy: Callable[[List[Dict]], str],
        goal: str,
    ):
        """Store identity, policy callable, and goal; start with fresh state.

        Args:
            seed: RNG seed kept for reproducibility bookkeeping.
            agent_id: Id used to match ``observation.current_agent``.
            agent_name: Human-readable name (interpolated into prompts).
            policy: Async callable producing the model output for a prompt.
            goal: Free-text goal interpolated into the intro prompt.
        """
        self.seed = seed
        self.agent_id = agent_id
        self.agent_name = agent_name
        self.policy = policy
        self.goal = goal
        self.state = NegotiationAgentState(
            round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[]
        )

        # Implemented in variants
        self.intro_prompt = ""
        self.new_round_prompt = ""
        self.last_round_prompt = ""
        self.send_split_prompt = ""
        self.wait_for_message_prompt = ""
        self.last_message_prompt = ""
        self.send_message_prompt = ""

    @abstractmethod
    def get_message_regex(self, observation: NegotiationObs) -> str:
        """Return the regex constraining a chat-message policy output."""
        pass

    @abstractmethod
    def get_split_regex(self, observation: NegotiationObs) -> str:
        """Return the regex constraining a split-proposal policy output."""
        pass

    @abstractmethod
    def get_split_action(
        self, policy_output: str, observation: NegotiationObs
    ) -> Split:
        """Parse a split-phase policy output into a Split action."""
        pass

    async def act(self, observation: NegotiationObs) -> Tuple[Any, AgentActLog]:
        """Build the prompt for this step, query the policy, return the action.

        Returns a (action, AgentActLog) pair where the action is a Message
        during our messaging turn, a Split during the split phase, or None
        when it is not our turn to produce anything. The log contains only
        the chat turns appended during this call.
        """
        is_our_turn = observation.current_agent == self.agent_id
        action: Any = None
        round_nb = observation.round_nb

        prompt_parts: List[str] = []
        # All observation fields become format() keys for the templates.
        # NOTE(review): every placeholder used by a variant's templates must
        # be satisfied by goal/agent/these fields — confirm per variant.
        obs_ctx = vars(observation)

        #######################################
        # build user prompt
        #######################################

        # First-ever call
        is_intro = round_nb == 0 and self.state.chat_counter == 0
        if is_intro:
            prompt_parts.append(
                self.intro_prompt.format(
                    goal=self.goal, agent=self.agent_name, **obs_ctx
                )
            )

        # New round
        is_new_round = round_nb > self.state.round_nb
        if is_new_round or is_intro:
            self.state.nb_messages_sent_this_round = 0
            if not is_intro:
                # Recap of the previous round only after the first round.
                prompt_parts.append(self.last_round_prompt.format(**obs_ctx))
            prompt_parts.append(self.new_round_prompt.format(**obs_ctx))
            self.state.round_nb = round_nb

        # Wait for message
        if not is_our_turn and not observation.split_phase:
            prompt_parts.append(self.wait_for_message_prompt.format(**obs_ctx))

        # Get last message
        if is_our_turn and not is_new_round and not is_intro:
            prompt_parts.append(self.last_message_prompt.format(**obs_ctx))

        # Prompt to send message
        must_send_message = not observation.split_phase and is_our_turn
        if must_send_message:
            prompt_parts.append(self.send_message_prompt.format(**obs_ctx))

        # Prompt to give split (split phase applies to both agents, so this
        # does not require is_our_turn)
        must_send_split = not must_send_message and observation.split_phase
        if must_send_split:
            prompt_parts.append(self.send_split_prompt.format(**obs_ctx))

        # Append one ChatTurn with is_state_end=True
        user_prompt = "\n".join(prompt_parts)
        self.state.chat_history.append(
            ChatTurn(
                agent_id=self.agent_id,
                role="user",
                content=user_prompt,
                is_state_end=True,
            )
        )

        #######################################
        # Get policy action
        #######################################

        # Query policy for the appropriate format
        if must_send_message:
            return_regex = self.get_message_regex(observation)
            policy_output = await self.policy(
                prompt=[c.dict() for c in self.state.chat_history],
                regex=return_regex,
            )
            self.state.chat_history.append(
                ChatTurn(
                    agent_id=self.agent_id,
                    role="assistant",
                    content=policy_output,
                    is_state_end=False,
                )
            )
            action = Message(message=policy_output)
            self.state.nb_messages_sent_this_round += 1

        elif must_send_split:
            return_regex = self.get_split_regex(observation)
            policy_output = await self.policy(
                prompt=[c.dict() for c in self.state.chat_history],
                regex=return_regex,
            )
            self.state.chat_history.append(
                ChatTurn(
                    agent_id=self.agent_id,
                    role="assistant",
                    content=policy_output,
                    is_state_end=False,
                )
            )
            action = self.get_split_action(policy_output, observation)
        else:
            # Neither messaging nor splitting: no action this step.
            action = None

        # Log only the turns appended during this call, then advance the
        # counter so the next call starts from the new history length.
        agent_step_log = AgentActLog(
            chat_turns=self.state.chat_history[self.state.chat_counter :], info=None
        )
        self.state.chat_counter = len(self.state.chat_history)
        return action, agent_step_log

    def get_safe_copy(self):
        """Return a shallow copy of the agent with a deep-copied state.

        The policy and identity fields are shared; only the mutable
        NegotiationAgentState is duplicated.
        """
        agent_copy = copy.copy(self)
        agent_copy.state = copy.deepcopy(self.state)
        return agent_copy

    def reset(self):
        """Discard all per-episode state (round counter and chat history)."""
        self.state = NegotiationAgentState(
            round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[]
        )
src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Negotiation simulation environment
3
+ other agent is set at the start of every round. Even though current agent changes over message turns in a round.
4
+ """
5
+ import copy
6
+ from abc import abstractmethod
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, List, Tuple
9
+
10
+ from numpy.random import default_rng
11
+
12
+ from mllm.markov_games.rollout_tree import SimulationStepLog
13
+ from mllm.markov_games.simulation import Simulation
14
+ from mllm.utils.get_coagent_id import get_coagent_id
15
+
16
+ AgentId = str
17
+
18
+
19
@dataclass
class Split:
    """Split action: how much of each item the proposing agent keeps."""

    # Item name (e.g. "coins") -> quantity claimed for the proposer itself.
    items_given_to_self: Dict[str, int]
+
23
+
24
@dataclass
class Message:
    """Chat action: free-form text sent to the other agent."""

    message: str
+
28
+
29
@dataclass  # gets extended by variants
class NegotiationState:
    """Mutable simulation state shared by all negotiation variants."""

    # 0-based round index; incremented when a round's splits resolve.
    round_nb: int
    # Most recent chat message this round ("" when none has been sent).
    last_message: str
    # Agent whose turn it currently is.
    current_agent: AgentId
    # Item name -> amount available to split this round.
    quantities: Dict[str, int]
    # Per-coin value of each agent for the current round.
    values: Dict[AgentId, float]
    # Split proposal recorded per agent; None until the agent submits one.
    splits: Dict[AgentId, Split | None]
    # Messages each agent has sent in the current round.
    nb_messages_sent: Dict[AgentId, int]
    # Snapshots of the previous round (None before the first round resolves).
    previous_values: Dict[AgentId, float] | None
    previous_splits: Dict[AgentId, Split | None] | None
    previous_points: Dict[AgentId, float] | None
    # True once messaging is over and both agents must submit splits.
    split_phase: bool
    # NOTE(review): NegotiationSimulation.step also assigns an `other_agent`
    # attribute on instances dynamically; it is not a declared field here —
    # confirm whether it should be (adding a defaulted field would break
    # dataclass field ordering in subclasses that add required fields).
+
43
+
44
@dataclass  # gets extended by variants
class NegotiationObs:
    """Per-agent observation common to all negotiation variants."""

    round_nb: int
    # Last message sent by the other agent ("" when none).
    last_message: str
    quota_messages_per_agent_per_round: int
    current_agent: AgentId
    # Display name of the other agent.
    other_agent: str
    quantities: Dict[str, int]
    item_types: List[str]
    # This agent's own per-coin value for the current round.
    value: float
    split_phase: bool
    # Previous-round outcome fields; all None before the first round resolves.
    # NOTE(review): last_split_* hold the number of coins proposed (an int) —
    # confirm the `int | None` typing is intended rather than a Split.
    last_split_agent: int | None
    last_value_agent: float | None
    last_points_agent: float | None
    last_split_coagent: int | None
    last_value_coagent: float | None
    last_points_coagent: float | None
+
62
+
63
def compute_tas_style_rewards(
    agent_ids: List[AgentId],
    values: Dict[AgentId, float],
    splits: Dict[AgentId, Split | None],
    max_coins: int,
    item: str = "coins",
) -> Dict[AgentId, float]:
    """
    TAS-like reward computation: if the sum of proposed coins exceeds
    ``max_coins``, allocate proportionally. Otherwise, use proposed amounts
    directly. Rewards are quantity_kept * per-coin value for each agent.

    Args:
        agent_ids: exactly two agent ids; order defines the returned mapping.
        values: per-coin value of each agent for this round.
        splits: split proposal per agent; a missing (None) split claims 0.
        max_coins: size of the pot being divided.
        item: key inside ``items_given_to_self`` to read (default "coins");
            parameterized so variants with other item names can reuse this.

    Returns:
        Mapping agent id -> reward for the round.
    """
    a0, a1 = agent_ids[0], agent_ids[1]

    def _claim(agent_id: AgentId) -> int:
        # A missing split proposal counts as claiming nothing.
        split = splits[agent_id]
        return int(split.items_given_to_self.get(item, 0)) if split is not None else 0

    claim0 = _claim(a0)
    claim1 = _claim(a1)
    # If the combined claims exceed the pot, both are scaled down
    # proportionally; otherwise claims pass through unchanged.
    # The trailing `or 1` guards division by zero when the pot is empty and
    # nothing is claimed (previously a ZeroDivisionError).
    denom = max(int(max_coins), claim0 + claim1) or 1
    q0 = float(max_coins) * float(claim0) / float(denom)
    q1 = float(max_coins) * float(claim1) / float(denom)
    return {a0: q0 * float(values[a0]), a1: q1 * float(values[a1])}
91
+
92
+
93
class NegotiationSimulation(Simulation):
    """Two-agent negotiation base environment.

    A game is a sequence of independent rounds. Depending on the variant, a
    round starts with a message phase (agents alternate chat turns until both
    reach their quota) and ends with a split phase in which both agents submit
    a Split proposal in the same timestep. Variant subclasses implement value
    sampling, rewards, observations and reset via the abstract hooks below.
    """

    def __init__(
        self,
        agent_ids: List[AgentId],
        agent_names: List[str],
        seed: int,
        nb_of_rounds: int,
        quota_messages_per_agent_per_round: int,
        item_types: List[str] | None = None,
    ):
        self.seed = seed
        self.rng = default_rng(self.seed)
        self.agent_ids = list(agent_ids)
        self.agent_names = agent_names
        self.agent_id_to_name = {
            agent_id: agent_name for agent_id, agent_name in zip(agent_ids, agent_names)
        }
        self.nb_of_rounds = int(nb_of_rounds)
        self.quota_messages_per_agent_per_round = int(
            quota_messages_per_agent_per_round
        )
        self.item_types = item_types or ["coins"]
        self.state: NegotiationState | None = None
        # Which of the two agents opens the first round; alternated after
        # every round so starting advantage is balanced.
        self._starting_agent_index = self.rng.choice([0, 1])
        self.reset()

    def _other(self, agent_id: AgentId) -> AgentId:
        """Return the id of the other agent."""
        return get_coagent_id(self.agent_ids, agent_id)

    @abstractmethod
    def set_new_round_of_variant(self):
        """Variant hook: resample values/quantities when a new round begins."""
        pass

    @abstractmethod
    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Variant hook: extra info to log when a round resolves."""
        pass

    def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
        """
        Advance the environment by one timestep.

        Returns terminated, step_log
        """
        assert self.state is not None
        current_agent = self.state.current_agent
        a0, a1 = self.agent_ids[0], self.agent_ids[1]
        action = actions.get(current_agent)

        # Split phase: require both splits in the same timestep
        if self.state.split_phase:
            action_a0 = actions.get(a0)
            action_a1 = actions.get(a1)
            have_both_splits = isinstance(action_a0, Split) and isinstance(
                action_a1, Split
            )
            if not have_both_splits:
                # Keep waiting (zero reward) until both proposals arrive.
                rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
                return False, SimulationStepLog(
                    rewards=rewards, info={"type": "waiting_for_splits"}
                )

            # Record splits
            self.state.splits[a0] = action_a0
            self.state.splits[a1] = action_a1

            # Compute rewards and end round
            rewards = self.get_rewards(self.state.splits)

            # Variant-specific info, captured before round state is cleared.
            info = self.get_info_of_variant(self.state, actions)

            # Prepare next round: alternate the starting agent and clear
            # round-local state, keeping snapshots for next-round observations.
            self.state.round_nb += 1
            self._starting_agent_index = 1 - self._starting_agent_index
            self.state.current_agent = self.agent_ids[self._starting_agent_index]
            # NOTE(review): `other_agent` is attached dynamically here; it is
            # not a declared NegotiationState field.
            self.state.other_agent = self.agent_id_to_name[
                self._other(self.state.current_agent)
            ]
            self.set_new_round_of_variant()  # variant specific
            self.state.previous_splits = copy.deepcopy(self.state.splits)
            self.state.previous_points = copy.deepcopy(rewards)
            self.state.last_message = ""
            self.state.splits = {agent_id: None for agent_id in self.agent_ids}
            self.state.nb_messages_sent = {agent_id: 0 for agent_id in self.agent_ids}
            is_last_timestep_in_round = True
            done = self.state.round_nb >= self.nb_of_rounds

        # Message phase
        elif isinstance(action, Message):
            self.state.last_message = action.message
            self.state.nb_messages_sent[current_agent] += 1

            # Move turn to other agent
            self.state.current_agent = self._other(current_agent)

            # If both agents have reached their message quota, enter split phase
            if all(
                self.state.nb_messages_sent[agent_id]
                >= self.quota_messages_per_agent_per_round
                for agent_id in self.agent_ids
            ):
                self.state.split_phase = True
            is_last_timestep_in_round = False
            done = False
            rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
            info = {"type": "message"}

        else:
            # Fix: previously this case fell through with rewards/info/done
            # unbound (UnboundLocalError) whenever the current agent produced
            # no Message outside the split phase. Treat it as a no-op timestep.
            is_last_timestep_in_round = False
            done = False
            rewards = {agent_id: 0.0 for agent_id in self.agent_ids}
            info = {"type": "noop"}

        info[
            "is_last_timestep_in_round"
        ] = is_last_timestep_in_round  # Used later to group round timesteps if needed
        return done, SimulationStepLog(rewards=rewards, info=info)

    def get_obs(self):
        """Returns all agent observations in dict"""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    @abstractmethod
    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Variant hook: map both agents' splits to per-agent rewards."""
        pass

    @abstractmethod
    def get_obs_agent(self, agent_id):
        """Variant hook: build the observation for a single agent."""
        pass

    def get_state(self):
        """Return the full (mutable) simulation state."""
        return self.state

    def get_safe_copy(self):
        """Return a safe copy of the simulation."""
        # Shallow-copy the simulation, deep-copy only the mutable state.
        simulation_copy = copy.copy(self)
        simulation_copy.state = copy.deepcopy(self.state)
        return simulation_copy

    @abstractmethod
    def reset(self) -> dict[AgentId, NegotiationObs]:
        """Variant hook: initialize state and return initial observations."""
        pass
src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict
4
+
5
+ from mllm.markov_games.rollout_tree import SimulationStepLog
6
+
7
+
8
def split_greed(sl: SimulationStepLog) -> Dict[str, float] | None:
    """Fraction of the coin pot each agent claimed for itself.

    Only meaningful on the final timestep of a round; returns None otherwise.
    Agents whose split entry cannot be read are silently skipped.
    """
    info = sl.info or {}
    if not info.get("is_last_timestep_in_round"):
        return None
    total_coins = float((info.get("quantities") or {}).get("coins", 1.0)) or 1.0
    greed: Dict[str, float] = {}
    for agent_id, split in (info.get("splits") or {}).items():
        try:
            claimed = float(split["items_given_to_self"]["coins"])
        except Exception:
            continue
        greed[str(agent_id)] = claimed / total_coins
    return greed
22
+
23
+
24
def split_efficiency(sl: SimulationStepLog) -> Dict[str, float] | None:
    """Achieved reward as a fraction of the best attainable reward.

    The optimum is the whole pot going to the highest-value agent. Only
    meaningful on a round's final timestep; returns None when the metric
    cannot be computed.
    """
    info = sl.info or {}
    if not info.get("is_last_timestep_in_round"):
        return None
    total_coins = float((info.get("quantities") or {}).get("coins", 1.0)) or 1.0
    per_coin_values = info.get("values") or {}
    if not per_coin_values:
        return None
    try:
        best_value = max(float(v) for v in per_coin_values.values())
    except Exception:
        return None
    if not total_coins or not best_value:
        return None
    realized = sum(float(r) for r in (sl.rewards or {}).values())
    attainable = total_coins * best_value
    if not attainable:
        return None
    # Efficiency is a global metric; emit same value for a special key "all"
    return {"all": realized / attainable}
src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Tuple
2
+
3
+ from mllm.markov_games.negotiation.nego_agent import (
4
+ NegotiationAgent,
5
+ NegotiationAgentState,
6
+ )
7
+ from mllm.markov_games.negotiation.nego_simulation import Split
8
+ from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressObs
9
+ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
10
+
11
+
12
class NoPressAgent(NegotiationAgent):
    """Agent for the no-communication ("no press") negotiation variant.

    Messaging is disabled: the message regex only matches the empty string,
    so the agent effectively only ever submits split proposals.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No communication in this variant
        # NOTE(review): the Protocol list below numbers its items 1, 4, 5, 6, 7
        # — looks copied from the chat variant with the messaging items removed;
        # confirm the prompt text is intended as-is.
        self.intro_prompt = (
            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
            "Setup:\n"
            "1. The game consists of multiple independent rounds.\n"
            "2. In each round, there are 10 coins to split between the two agents.\n"
            "3. Each round, both agents are randomly assigned a value of either 1 or 10 per coin.\n"
            "4. You can observe values of both agents.\n"
            "5. Because assignments are random, both agents are equally likely to have same expected per-coin value.\n"
            "\n"
            "Protocol:\n"
            "1. Both agents simultaneously propose how many coins they keep.\n"
            "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n"
            "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n"
            "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n"
            "7. The points are accumulated across rounds.\n"
            "Your goal: {goal}\n"
        )
        self.new_round_prompt = "In this round, your per-coin value is {value} and {other_agent}'s per-coin value is {other_value}."
        self.last_round_prompt = "In the last round, your per-coin value was {last_value_agent} and {other_agent}'s per-coin value was {last_value_coagent}.\nYou proposed {last_split_agent} coins and earned {last_points_agent} points, while {other_agent} proposed {last_split_coagent} coins and earned {last_points_coagent} points."
        self.send_split_prompt = "Respond with <coins_to_self> X </coins_to_self> where X is the number of coins you propose for yourself, between 0 and 10 inclusive."

    def get_message_regex(self, observation: NoPressObs) -> str:
        # Only the empty string matches: messaging is disabled in this variant.
        return r"^$"  # No messages allowed

    def get_split_regex(self, observation: NoPressObs) -> str:
        # Accepts <coins_to_self> N </coins_to_self> with N in 0..10.
        return r"<coins_to_self> ?(10|[0-9]) ?</coins_to_self>"

    def get_split_action(self, policy_output: str, observation: NoPressObs) -> Split:
        """Parse the constrained policy output into a Split action."""
        import re as _re

        m = _re.search(r"<coins_to_self> ?(10|[0-9]) ?</coins_to_self>", policy_output)
        # Fallback: interpret the raw output as an integer if the tag is absent.
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Tuple
4
+
5
+ from mllm.markov_games.negotiation.nego_simulation import (
6
+ NegotiationObs,
7
+ NegotiationSimulation,
8
+ NegotiationState,
9
+ Split,
10
+ compute_tas_style_rewards,
11
+ )
12
+
13
+ AgentId = str
14
+
15
+
16
@dataclass
class NoPressState(NegotiationState):
    """State for the no-press variant; adds nothing to the base state."""

    pass
19
+
20
+
21
@dataclass
class NoPressObs(NegotiationObs):
    """Observation for the no-press variant, where values are public."""

    # The other agent's per-coin value (observable in this variant).
    other_value: float
24
+
25
+
26
class NoPressSimulation(NegotiationSimulation):
    """No-communication negotiation: every timestep is a split phase.

    Both agents see both per-coin values (1 vs 10). With ``deterministic=True``
    the round's opening agent always gets the low value; otherwise the low/high
    assignment is sampled each round.
    """

    def __init__(
        self,
        deterministic: bool,
        *args,
        **kwargs,
    ):
        # Must be set before super().__init__, which calls reset().
        self.deterministic = deterministic
        super().__init__(*args, **kwargs)

    def _sample_values(self) -> Dict[AgentId, float]:
        """Randomly give one agent value 1 and the other value 10."""
        v = float(int(self.rng.choice([1, 10])))
        return {self.agent_ids[0]: v, self.agent_ids[1]: 10.0 if v == 1.0 else 1.0}

    def set_new_round_of_variant(self):
        """Snapshot last round's values and assign fresh ones."""
        self.state.previous_values = copy.deepcopy(self.state.values)
        self.state.quantities = {"coins": 10.0}
        if self.deterministic:
            # The agent who opens the new round gets the low value.
            self.state.values = {
                aid: 1.0 if aid == self.state.current_agent else 10.0
                for aid in self.agent_ids
            }
        else:
            self.state.values = self._sample_values()
        # No messaging in this variant: go straight to the split phase.
        self.state.split_phase = True

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Log quantities, values and splits when a round resolves."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "values": copy.deepcopy(state.values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Proportional TAS-style rewards over a pot of 10 coins."""
        return compute_tas_style_rewards(
            self.agent_ids, self.state.values, splits, 10.0
        )

    def get_obs(self):
        """Return all agent observations keyed by agent id."""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    def get_obs_agent(self, agent_id):
        """Build one agent's observation, including last-round outcomes."""
        other_id = self._other(agent_id)
        # All previous_* fields are None before the first round resolves.
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self["coins"]
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[
                "coins"
            ]
        obs = NoPressObs(
            round_nb=self.state.round_nb,
            last_message="",
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=other_id,
            quantities={"coins": 10},
            item_types=self.item_types,
            value=self.state.values[agent_id],
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            other_value=self.state.values[other_id],
        )
        return obs

    def reset(self):
        """Initialize state and return the initial observations."""
        start_agent = self.agent_ids[self._starting_agent_index]
        if self.deterministic:
            # Opening agent gets the low value deterministically.
            values = {
                aid: 1.0 if aid == start_agent else 10.0 for aid in self.agent_ids
            }
        else:
            values = self._sample_values()
        self.state = NoPressState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities={"coins": 10.0},
            values=values,
            previous_values=None,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            split_phase=True,
            previous_splits=None,
            previous_points=None,
        )
        return self.get_obs()
src_code_for_reproducibility/markov_games/negotiation/tas_agent.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
2
+ from mllm.markov_games.negotiation.nego_simulation import Split
3
+ from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
4
+
5
+
6
class TrustAndSplitAgent(NegotiationAgent):
    """Agent for the Trust-and-Split variant with private per-coin values.

    Agents chat (bounded message length) to communicate their private values,
    then simultaneously submit split proposals over 10 coins.
    """

    def __init__(self, num_message_chars, *args, **kwargs):
        # Maximum characters allowed per chat message; must be set before
        # super().__init__ so prompt construction below can use it.
        self.num_message_chars = num_message_chars
        super().__init__(*args, **kwargs)
        self.intro_prompt = (
            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
            "Setup:\n"
            "1. The game has multiple independent rounds.\n"
            "2. In each round, there are 10 coins to split between the two agents.\n"
            "3. Both agents are assigned a private per-coin value between 1 and 20 (inclusive) in each round.\n"
            "4. Because assignments are random, both agents are equally likely to have same expected per-coin value.\n"
            "\n"
            "Protocol:\n"
            "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n"
            "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the 10 coins.\n"
            "   - Use this chat to communicate your private per-coin value to make informed proposals.\n"
            "3. After the chat, both agents simultaneously propose how many coins they keep.\n"
            "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n"
            "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n"
            "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n"
            "7. The points are accumulated across rounds.\n"
            "Your goal: {goal}\n"
        )
        self.new_round_prompt = "A new round begins\n" "Your per-coin value is {value}."
        self.last_round_prompt = (
            "Round summary:\n"
            "  - Your value per coin: {last_value_agent}\n"
            "  - {other_agent}'s value per coin: {last_value_coagent}\n"
            "  - You proposed: {last_split_agent} coins\n"
            "  - You earned: {last_points_agent} points\n"
            "  - {other_agent} proposed: {last_split_coagent} coins\n"
            "  - {other_agent} earned: {last_points_coagent} points\n"
            "  - Round complete.\n"
        )
        self.send_split_prompt = (
            "Submit your proposal\n"
            "Respond with <coins_to_self> x </coins_to_self> where x is an integer in [0, 10]."
        )
        self.wait_for_message_prompt = "Wait for {other_agent} to send a message..."
        self.last_message_prompt = "{other_agent} said: {last_message}"
        self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."

    def get_message_regex(self, observation: TrustAndSplitObs) -> str:
        # Message body of up to num_message_chars characters inside the tag.
        return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"

    def get_split_regex(self, observation: TrustAndSplitObs) -> str:
        # Accepts <coins_to_self> N </coins_to_self> with N in 0..10.
        return r"<coins_to_self> ?(10|[0-9]) ?</coins_to_self>"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitObs
    ) -> Split:
        """Parse the constrained policy output into a Split action."""
        import re as _re

        m = _re.search(r"<coins_to_self> ?(10|[0-9]) ?</coins_to_self>", policy_output)
        # Fallback: interpret the raw output as an integer if the tag is absent.
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Tuple
5
+
6
+ from mllm.markov_games.agent import Agent
7
+ from mllm.markov_games.negotiation.nego_agent import (
8
+ Message,
9
+ NegotiationAgent,
10
+ NegotiationAgentState,
11
+ Split,
12
+ )
13
+ from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSObs
14
+ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
15
+
16
+
17
class TrustAndSplitRPSAgent(NegotiationAgent):
    """Agent for the Trust-and-Split rock-paper-scissors variant.

    Each agent sees only its own RPS hand; the winning hand implies a per-coin
    value of 10 and the losing hand a value of 1. Agents chat to reveal hands,
    then simultaneously submit split proposals over 10 coins.
    """

    def __init__(
        self,
        num_message_chars: int,
        *args,
        **kwargs,
    ):
        # Maximum characters allowed per chat message; must be set before
        # super().__init__ so prompt construction below can use it.
        self.num_message_chars = num_message_chars
        super().__init__(*args, **kwargs)
        self.intro_prompt = (
            "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
            "\n"
            "Setup:\n"
            "1. The game has multiple independent rounds.\n"
            "2. In each round, there are 10 coins to split between the two agents.\n"
            "3. Each agent's per-coin value for that round is determined as follows:\n"
            "   - Both agents are randomly assigned a rock, paper or scissors hands\n"
            "   - Rock has the upper hand over scissors, scissors has the upper hand over paper and paper has the upper hand over rock.\n"
            "   - The agent with the upper hand has a per-coin value of 10.\n"
            "   - The agent with the lower hand has a per-coin value of 1.\n"
            "4. You only see your own hand, but you may communicate it in messages and infer your value based on the other agent's hand.\n"
            "5. Over many rounds both agents are equally likely to have the upper and lower hand.\n"
            "\n"
            "Protocol:\n"
            "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n"
            "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the 10 coins.\n"
            "   - Use this chat to communicate your hand so that both agents can determine their per-coin values.\n"
            "3. After the chat, both agents simultaneously propose how many coins they keep.\n"
            "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n"
            "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n"
            "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n"
            "7. The points are accumulated across rounds.\n"
            "Your goal: {goal}\n"
        )
        self.new_round_prompt = "A new round begins\n" "Your hand is {hand}."
        self.last_round_prompt = (
            "Round summary:\n"
            "  - Your hand: {last_hand_agent}\n"
            "  - {other_agent}'s hand: {last_hand_coagent}\n"
            "  - Your value per coin: {last_value_agent}\n"
            "  - {other_agent}'s value per coin: {last_value_coagent}\n"
            "  - You proposed: {last_split_agent} coins\n"
            "  - You earned: {last_points_agent} points\n"
            "  - {other_agent} proposed: {last_split_coagent} coins\n"
            "  - {other_agent} earned: {last_points_coagent} points\n"
            "  - Round complete.\n"
        )
        self.send_split_prompt = (
            "Submit your proposal\n"
            "Respond with <coins_to_self> x </coins_to_self> where x is an integer in [0, 10]."
        )
        self.wait_for_message_prompt = "Wait for {other_agent} to send a message..."
        self.last_message_prompt = "{other_agent} said: {last_message}"
        self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."

    def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
        # Message body of up to num_message_chars characters inside the tag.
        return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"

    def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
        # Accepts <coins_to_self> N </coins_to_self> with N in 0..10.
        return r"<coins_to_self> ?(10|[0-9]) ?</coins_to_self>"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitRPSObs
    ) -> Split:
        """Parse the constrained policy output into a Split action."""
        import re as _re

        m = _re.search(r"<coins_to_self> ?(10|[0-9]) ?</coins_to_self>", policy_output)
        # Fallback: interpret the raw output as an integer if the tag is absent.
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
src_code_for_reproducibility/markov_games/negotiation/tas_rps_simulation.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trust-and-Split simulation.
3
+
4
+ This environment models a simple bargaining game over 10 coins with messaging.
5
+ Agents are assigned rock/paper/scissors hands, with the winner getting value 10 per coin
6
+ and the loser getting value 1 per coin. Agents alternate sending messages for a fixed
7
+ number of turns per round and then each submits a split proposal indicating how many
8
+ coins they keep for themselves. Rewards are proportional if the proposed totals exceed 10.
9
+ """
10
+
11
+ import copy
12
+ from dataclasses import dataclass
13
+ from typing import Any, Dict, List, Literal, Tuple
14
+
15
+ from numpy.random import default_rng
16
+
17
+ from mllm.markov_games.negotiation.nego_simulation import (
18
+ Message,
19
+ NegotiationObs,
20
+ NegotiationSimulation,
21
+ NegotiationState,
22
+ Split,
23
+ compute_tas_style_rewards,
24
+ )
25
+ from mllm.markov_games.rollout_tree import SimulationStepLog
26
+
27
+ AgentId = str
28
+
29
+
30
def _get_rps_winner(
    hand1: Literal["rock", "paper", "scissors"],
    hand2: Literal["rock", "paper", "scissors"],
) -> Literal["rock", "paper", "scissors"]:
    """Return the winning hand of a rock-paper-scissors match-up.

    Raises ValueError on a tie (ties never occur in this game, since the two
    hands are sampled without replacement).
    """
    if hand1 == hand2:
        raise ValueError("Hands should be different")
    # Map each hand to the hand it defeats.
    beats = {"rock": "scissors", "paper": "rock", "scissors": "paper"}
    return hand1 if beats[hand1] == hand2 else hand2
45
+
46
+
47
@dataclass
class TrustAndSplitRPSState(NegotiationState):
    """State for the RPS variant: adds current and previous hands."""

    hands: Dict[
        AgentId, Literal["rock", "paper", "scissors"]
    ]  # rock, paper, or scissors
    # Hands of the previous round (None before the first round resolves).
    previous_hands: Dict[AgentId, Literal["rock", "paper", "scissors"]] | None
53
+
54
+
55
@dataclass
class TrustAndSplitRPSObs(NegotiationObs):
    """Observation for the RPS variant: own hand plus last round's hands."""

    # This agent's hand for the current round (hands are private).
    hand: Literal["rock", "paper", "scissors"]
    # Hands from the previous round, revealed once it resolves.
    last_hand_agent: Literal["rock", "paper", "scissors"] | None
    last_hand_coagent: Literal["rock", "paper", "scissors"] | None
60
+
61
+
62
class TrustAndSplitRPSSimulation(NegotiationSimulation):
    """Trust-and-Split with rock-paper-scissors-determined values.

    Each round both agents get distinct RPS hands; the winning hand's owner
    values coins at 10, the loser at 1. Agents chat (message phase) and then
    simultaneously submit split proposals over 10 coins.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def _sample_hands_and_values(
        self,
    ) -> Tuple[Dict[AgentId, str], Dict[AgentId, float]]:
        """Sample two distinct hands and derive per-coin values from them."""
        # Assign different hands to each agent
        hands = ["rock", "paper", "scissors"]
        # replace=False guarantees distinct hands, so there is never a tie.
        hand1, hand2 = self.rng.choice(hands, size=2, replace=False)

        agent_hands = {self.agent_ids[0]: hand1, self.agent_ids[1]: hand2}

        # Determine winner and assign values
        winner = _get_rps_winner(hand1, hand2)
        values = {}
        for agent_id in self.agent_ids:
            if agent_hands[agent_id] == winner:
                values[agent_id] = 10.0  # Winner gets value 10
            else:
                values[agent_id] = 1.0  # Loser gets value 1

        return agent_hands, values

    def set_new_round_of_variant(self):
        """Snapshot last round's hands/values and sample fresh ones."""
        self.state.previous_values = copy.deepcopy(self.state.values)
        self.state.previous_hands = copy.deepcopy(self.state.hands)
        new_hands, new_values = self._sample_hands_and_values()
        self.state.hands = new_hands
        self.state.values = new_values
        # Quantities are constant in TAS
        self.state.quantities = {"coins": 10}
        # New rounds start in the message phase.
        self.state.split_phase = False

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Log hands, values and splits (current and previous) at round end."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "hands": copy.deepcopy(state.hands),
            "values": copy.deepcopy(state.values),
            "previous_hands": copy.deepcopy(state.previous_hands),
            "previous_values": copy.deepcopy(state.previous_values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Proportional TAS-style rewards over a pot of 10 coins."""
        return compute_tas_style_rewards(self.agent_ids, self.state.values, splits, 10)

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id"""
        other_id = self._other(agent_id)
        # All previous_* fields are None before the first round resolves.
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        last_hand_coagent = (
            None
            if self.state.previous_hands is None
            else self.state.previous_hands.get(other_id)
        )
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_hand_agent = (
            None
            if self.state.previous_hands is None
            else self.state.previous_hands.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self["coins"]
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[
                "coins"
            ]
        obs = TrustAndSplitRPSObs(
            round_nb=self.state.round_nb,
            last_message=self.state.last_message,
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=other_id,
            quantities={"coins": 10},
            item_types=self.item_types,
            value=self.state.values[agent_id],
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            hand=self.state.hands[agent_id],
            last_hand_coagent=last_hand_coagent,
            last_hand_agent=last_hand_agent,
        )
        return obs

    def get_state(self):
        """Return the full (mutable) simulation state."""
        return self.state

    def get_safe_copy(self):
        """Return a safe copy of the simulation."""
        # Shallow-copy the simulation, deep-copy only the mutable state.
        simulation_copy = copy.copy(self)
        simulation_copy.state = copy.deepcopy(self.state)
        return simulation_copy

    def reset(self):
        """Initialize and return initial observations"""
        # Decide starting agent alternating across resets for determinism
        start_agent = self.agent_ids[self._starting_agent_index]
        hands, values = self._sample_hands_and_values()
        self.state = TrustAndSplitRPSState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities={"coins": 10},
            values=values,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            previous_values=None,
            previous_splits=None,
            previous_points=None,
            split_phase=False,
            hands=hands,
            previous_hands=None,
        )
        return self.get_obs()