Muqeeth commited on
Commit
bcf4380
·
verified ·
1 Parent(s): 7f1ed79

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. src_code_for_reproducibility/__init__.py +0 -0
  2. src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc +0 -0
  3. src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc +0 -0
  4. src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc +0 -0
  5. src_code_for_reproducibility/chat_utils/apply_template.py +84 -0
  6. src_code_for_reproducibility/chat_utils/template_specific.py +109 -0
  7. src_code_for_reproducibility/docs/source/contributing.rst +0 -0
  8. src_code_for_reproducibility/docs/source/environments/diplomacy.rst +459 -0
  9. src_code_for_reproducibility/docs/source/environments/dond.rst +410 -0
  10. src_code_for_reproducibility/docs/source/environments/ipd.rst +411 -0
  11. src_code_for_reproducibility/docs/source/index.rst +22 -0
  12. src_code_for_reproducibility/docs/source/launch.rst +0 -0
  13. src_code_for_reproducibility/docs/source/media/runbatch.png +0 -0
  14. src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst +7 -0
  15. src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst +7 -0
  16. src_code_for_reproducibility/docs/source/src.environments.env_imports.rst +7 -0
  17. src_code_for_reproducibility/docs/source/src.models.hf_agent.rst +7 -0
  18. src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst +7 -0
  19. src_code_for_reproducibility/docs/source/src.models.oai_agent.rst +7 -0
  20. src_code_for_reproducibility/docs/source/src.models.server_llm.rst +7 -0
  21. src_code_for_reproducibility/docs/source/src.models.vllm_worker_wrap.rst +7 -0
  22. src_code_for_reproducibility/docs/source/src.run.rst +7 -0
  23. src_code_for_reproducibility/docs/source/src.utils.extra_stats.rst +7 -0
  24. src_code_for_reproducibility/docs/source/src.utils.inherit_args.rst +7 -0
  25. src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst +7 -0
  26. src_code_for_reproducibility/docs/source/src.utils.model_to_cpu.rst +7 -0
  27. src_code_for_reproducibility/docs/source/src.utils.quick_stats.rst +7 -0
  28. src_code_for_reproducibility/docs/source/usage.rst +0 -0
  29. src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc +0 -0
  30. src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc +0 -0
  31. src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc +0 -0
  32. src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging_for_training.py +0 -0
  33. src_code_for_reproducibility/markov_games/markov_game.py +208 -0
  34. src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc +0 -0
  35. src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py +64 -0
  36. src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py +244 -0
  37. src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc +0 -0
  38. src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc +0 -0
  39. src_code_for_reproducibility/training/__pycache__/produce_training_stats.cpython-312.pyc +0 -0
  40. src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc +0 -0
  41. src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc +0 -0
  42. src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc +0 -0
  43. src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc +0 -0
  44. src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc +0 -0
  45. src_code_for_reproducibility/training/tally_tokenwise.py +276 -0
  46. src_code_for_reproducibility/training/tokenize_chats.py +128 -0
  47. src_code_for_reproducibility/training/trainer_sum_rewards.py +127 -0
  48. src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc +0 -0
  49. src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc +0 -0
  50. src_code_for_reproducibility/utils/get_coagent_id.py +4 -0
src_code_for_reproducibility/__init__.py ADDED
File without changes
src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc ADDED
Binary file (3.92 kB). View file
 
src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc ADDED
Binary file (1.32 kB). View file
 
src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc ADDED
Binary file (4.24 kB). View file
 
src_code_for_reproducibility/chat_utils/apply_template.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from mllm.chat_utils.chat_turn import ChatTurn
4
+ from mllm.chat_utils.template_specific import (
5
+ custom_gemma3_template,
6
+ custom_llama3_template,
7
+ custom_qwen2_template,
8
+ custom_qwen3_template,
9
+ gemma3_assistant_postfix,
10
+ qwen2_assistant_postfix,
11
+ qwen3_assistant_postfix,
12
+ )
13
+
14
+
15
def get_custom_chat_template(tokenizer) -> str:
    """
    Pick the custom chat template that matches the tokenizer's model family.

    The family is detected by a case-insensitive substring search in
    ``tokenizer.name_or_path``. The checks run in the order qwen2, llama,
    qwen3, gemma, and the first match wins.

    Raises:
        ValueError: if none of the known family names occurs in
            ``tokenizer.name_or_path``.
    """
    model_name = tokenizer.name_or_path.lower()
    if "qwen2" in model_name:
        return custom_qwen2_template
    if "llama" in model_name:
        return custom_llama3_template
    if "qwen3" in model_name:
        return custom_qwen3_template
    if "gemma" in model_name:
        return custom_gemma3_template
    raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
29
+
30
+
31
def get_custom_assistant_postfix(tokenizer) -> torch.Tensor:
    """
    Return the assistant postfix token ids that match the tokenizer's family.

    The family is detected by a case-insensitive substring search in
    ``tokenizer.name_or_path`` (qwen2, then qwen3, then gemma). Any other
    tokenizer gets an empty ``torch.long`` tensor, i.e. no postfix.
    """
    model_name = tokenizer.name_or_path.lower()
    if "qwen2" in model_name:
        return qwen2_assistant_postfix
    if "qwen3" in model_name:
        return qwen3_assistant_postfix
    if "gemma" in model_name:
        return gemma3_assistant_postfix
    # No family-specific postfix: an empty tensor keeps torch.cat callers working.
    return torch.tensor([], dtype=torch.long)
42
+
43
+
44
def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
    """
    Set ``chat_template_token_ids`` on every chat turn that does not have it yet.

    User turns are rendered one at a time through
    ``tokenizer.apply_chat_template`` with the family-specific custom template.
    Assistant turns reuse the ids already stored in ``out_token_ids``. Turns
    whose ``chat_template_token_ids`` is already set are left untouched.

    Args:
        chats: Conversation turns in order; they are mutated in place.
        tokenizer: Object exposing ``name_or_path`` and ``apply_chat_template``.
        enable_thinking: Forwarded unchanged to the chat template.

    Raises:
        ValueError: if the tokenizer family is unsupported, or if a turn that
            still needs token ids has a role other than "user" or "assistant".
    """
    # TODO: use engine tokens if available
    custom_template = get_custom_chat_template(tokenizer)
    custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
    for i, chat in enumerate(chats):
        if chat.chat_template_token_ids is not None:
            continue
        if chat.role == "user":
            next_chat = chats[i + 1] if i + 1 < len(chats) else None
            # Open an assistant header unless another user turn follows directly.
            add_generation_prompt = not (next_chat and next_chat.role == "user")
            encoded_chat = tokenizer.apply_chat_template(
                [chat],
                return_tensors="pt",
                chat_template=custom_template,
                add_generation_prompt=add_generation_prompt,
                # Only the very first turn carries the system prompt / BOS part.
                add_system_prompt=i == 0,
                enable_thinking=enable_thinking,
            ).flatten()
            previous_chat = chats[i - 1] if i > 0 else None
            if previous_chat and previous_chat.role == "assistant":
                # Prepend the family-specific postfix after an assistant turn.
                encoded_chat = torch.cat([custom_assistant_postfix, encoded_chat])
        elif chat.role == "assistant":
            encoded_chat = chat.out_token_ids
        else:
            # Previously this fell through with `encoded_chat` unbound (crash on
            # the first turn) or stale (silently duplicating the previous
            # turn's tokens). Fail loudly instead.
            raise ValueError(
                f"Unsupported chat role {chat.role!r} at index {i}; "
                "expected 'user' or 'assistant'"
            )
        chat.chat_template_token_ids = encoded_chat
72
+
73
+
74
def chat_turns_to_token_ids(
    chats: list[ChatTurn], tokenizer, enable_thinking
) -> torch.Tensor:
    """
    Tokenize the chat turns and return all their token ids as one tensor.

    Calls ``tokenize_chats`` to fill in ``chat_template_token_ids`` on each
    turn (mutating ``chats`` in place), then concatenates the per-turn tensors
    in conversation order.

    Returns:
        The result of ``torch.cat`` over every turn's
        ``chat_template_token_ids``. The original annotation said
        ``list[int]``, but a tensor has always been returned.

    Note:
        ``chats`` must be non-empty; ``torch.cat`` rejects an empty list.
    """
    tokenize_chats(chats=chats, tokenizer=tokenizer, enable_thinking=enable_thinking)
    return torch.cat([chat.chat_template_token_ids for chat in chats])
src_code_for_reproducibility/chat_utils/template_specific.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import huggingface_hub
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+
5
+ custom_llama3_template = """
6
+ {%- if add_system_prompt %}
7
+ {{- '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|>' }}
8
+ {%- endif %}
9
+ {%- for message in messages %}
10
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
11
+ {%- endfor %}
12
+
13
+ {%- if add_generation_prompt %}
14
+ {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
15
+ {%- endif %}
16
+ """
17
+
18
def _newline_token_ids(model_id):
    # Load the tokenizer for `model_id` and encode a lone "\n" into a flat
    # tensor of token ids.
    # NOTE(review): encode() runs with its default special-token handling;
    # confirm no BOS token is prepended for families that add one by default.
    newline_tokenizer = AutoTokenizer.from_pretrained(model_id)
    return newline_tokenizer.encode("\n", return_tensors="pt").flatten()


# Per-family encodings of "\n", computed once when this module is imported.
# Each call loads a tokenizer through AutoTokenizer.from_pretrained.
qwen2_assistant_postfix = _newline_token_ids("Qwen/Qwen2.5-7B-Instruct")
qwen3_assistant_postfix = _newline_token_ids("Qwen/Qwen3-8B")
gemma3_assistant_postfix = _newline_token_ids("google/gemma-3-4b-it")
33
+ custom_qwen2_template = """
34
+ {%- if add_system_prompt %}
35
+ {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
36
+ {%- endif %}
37
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
38
+ {%- for message in messages %}
39
+ {%- if message.content is string %}
40
+ {%- set content = message.content %}
41
+ {%- else %}
42
+ {%- set content = '' %}
43
+ {%- endif %}
44
+ {%- if (message.role == "user") %}
45
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
46
+ {%- elif message.role == "assistant" %}
47
+ {%- set reasoning_content = '' %}
48
+ {%- if message.reasoning_content is string %}
49
+ {%- set reasoning_content = message.reasoning_content %}
50
+ {%- else %}
51
+ {%- if '</think>' in content %}
52
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
53
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
54
+ {%- endif %}
55
+ {%- endif %}
56
+ {%- if loop.index0 > ns.last_query_index %}
57
+ {%- if reasoning_content %}
58
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
59
+ {%- else %}
60
+ {{- '<|im_start|>' + message.role + '\n' + content }}
61
+ {%- endif %}
62
+ {%- else %}
63
+ {{- '<|im_start|>' + message.role + '\n' + content }}
64
+ {%- endif %}
65
+ {{- '<|im_end|>\n' }}
66
+ {%- endif %}
67
+ {%- endfor %}
68
+ {%- if add_generation_prompt %}
69
+ {{- '<|im_start|>assistant\n' }}
70
+ {%- endif %}
71
+ """
72
+
73
+ custom_qwen3_template = """
74
+ {%- for message in messages %}
75
+ {%- if message.content is string %}
76
+ {%- set content = message.content %}
77
+ {%- else %}
78
+ {%- set content = '' %}
79
+ {%- endif %}
80
+ {%- if (message.role == "user") %}
81
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
82
+ {%- elif message.role == "assistant" %}
83
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
84
+ {%- endif %}
85
+ {%- endfor %}
86
+ {%- if add_generation_prompt %}
87
+ {{- '<|im_start|>assistant\n' }}
88
+ {%- if enable_thinking is defined and enable_thinking is false %}
89
+ {{- '<think>\n\n</think>\n\n' }}
90
+ {%- endif %}
91
+ {%- endif %}
92
+ """
93
+
94
+ custom_gemma3_template = """
95
+ {%- if add_system_prompt %}
96
+ {{- bos_token -}}
97
+ {%- endif %}
98
+ {%- for message in messages -%}
99
+ {%- if message['role'] == 'assistant' -%}
100
+ {%- set role = 'model' -%}
101
+ {%- else -%}
102
+ {%- set role = message['role'] -%}
103
+ {%- endif -%}
104
+ {{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
105
+ {%- endfor -%}
106
+ {%- if add_generation_prompt -%}
107
+ {{ '<start_of_turn>model\n' }}
108
+ {%- endif -%}
109
+ """
src_code_for_reproducibility/docs/source/contributing.rst ADDED
File without changes
src_code_for_reproducibility/docs/source/environments/diplomacy.rst ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Diplomacy
3
+ =================
4
+
5
+ The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
6
+ based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
13
+ and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
14
+ of movement phases, retreat phases, and build phases.
15
+
16
+ Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
17
+ to be used with LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ Game Board and Powers
+ ~~~~~~~~~~~~~~~~~~~~~
23
+
24
+ Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
25
+
26
+ - England (blue)
27
+ - France (light blue)
28
+ - Germany (black)
29
+ - Italy (green)
30
+ - Austria-Hungary (red)
31
+ - Russia (white)
32
+ - Turkey (yellow)
33
+
34
+ Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
35
+
36
+ Units and Movement
+ ~~~~~~~~~~~~~~~~~~
37
+
38
+ There are two types of units in Diplomacy:
39
+ - **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
40
+ - **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
41
+
42
+ During movement phases, each unit can execute one of these orders:
43
+ - **Hold**: The unit remains in its current province (e.g., "A PAR H")
44
+ - Format: [Unit Type] [Province] H
45
+ - Example: "A PAR H" means "Army in Paris holds its position"
46
+
47
+ - **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
48
+ - Format: [Unit Type] [Current Province] - [Destination Province]
49
+ - Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
50
+ - Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
51
+
52
+ - **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
53
+ - Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
54
+ - Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
55
+ - Example: "A PAR S A MAR - BUR" means "Army in Paris supports the Army in Marseille's move to Burgundy"
56
+ - Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
57
+
58
+ - **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
59
+ - Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
60
+ - Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
61
+
62
+ All orders are executed simultaneously, and conflicts are resolved based on strength (number of supporting units).
63
+
64
+ Common Province Abbreviations
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
65
+
66
+ Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
67
+ - **PAR**: Paris
68
+ - **LON**: London
69
+ - **BER**: Berlin
70
+ - **MUN**: Munich
71
+ - **BUR**: Burgundy
72
+ - **MAR**: Marseilles
73
+ - **BRE**: Brest
74
+ - **ENG**: English Channel
75
+ - **NTH**: North Sea
76
+ - **VIE**: Vienna
77
+ - **ROM**: Rome
78
+ - **VEN**: Venice
79
+ - **MOW**: Moscow
80
+ - **CON**: Constantinople
81
+
82
+ Example: Movement and Conflicts
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
83
+
84
+ For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
85
+
86
+ Turn Structure
+ ~~~~~~~~~~~~~~
87
+
88
+ A game year consists of five phases:
89
+ 1. **Spring Movement**: All powers submit orders for their units
90
+ 2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
91
+ 3. **Fall Movement**: Another round of movement orders
92
+ 4. **Fall Retreat**: Retreat orders for dislodged units
93
+ 5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
94
+
95
+ Supply Centers and Building
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
96
+
97
+ Supply centers (marked on the map) are key to victory. When a power occupies a supply center during a Fall turn, they gain control of it. During the Winter Adjustment phase:
98
+ - If you control more supply centers than you have units, you can build new units in your home supply centers
99
+ - If you control fewer supply centers than you have units, you must remove excess units
100
+
101
+ Example: Building and Removing Units
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102
+
103
+ If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
104
+
105
+ Negotiation
+ ~~~~~~~~~~~
106
+
107
+ A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
108
+
109
+ Example: Alliance and Betrayal
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
110
+
111
+ England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
112
+
113
+ Victory Conditions
+ ~~~~~~~~~~~~~~~~~~
114
+
115
+ The game ends when one power controls 18 or more supply centers (majority of the 34 total centers), or when players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
116
+
117
+ DiplomacyEnv
118
+ ------------
119
+
120
+ The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
121
+ Negotiation Environment standard.
122
+
123
+ .. code-block:: python
124
+
125
+ class DiplomacyEnv:
126
+ """
127
+ Multi-Agent Negotiation Environment for Diplomacy, adapting DeepMind's implementation
128
+ to the MarlEnvironment standard.
129
+ """
130
+ def __init__(self,
131
+ initial_state: Optional[DiplomacyState] = None,
132
+ max_turns: int = 100,
133
+ points_per_supply_centre: bool = True,
134
+ forced_draw_probability: float = 0.0,
135
+ min_years_forced_draw: int = 35):
136
+ """Initialize the Diplomacy environment.
137
+
138
+ Args:
139
+ initial_state: Initial DiplomacyState (optional)
140
+ max_turns: Maximum number of turns in the game
141
+ points_per_supply_centre: Whether to award points per supply center in case of a draw
142
+ forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
143
+ min_years_forced_draw: Minimum years before considering a forced draw
144
+ """
145
+ # ...
146
+
147
+ def reset(self):
148
+ """Reset the environment to an initial state and return the initial observation.
149
+
150
+ Returns:
151
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
152
+ Each observation contains:
153
+ - board_state: Current state of the board
154
+ - current_season: Current season in the game
155
+ - player_index: Index of the player's power
156
+ - possible_actions: List of possible actions in DeepMind's format
157
+ - human_readable_actions: List of human-readable action descriptions
158
+ - supply_centers: List of supply centers owned by the player
159
+ - units: List of units owned by the player
160
+ - year: Current year in the game
161
+ """
162
+ # ...
163
+
164
+ def step(self, actions):
165
+ """Take a step in the environment using the provided actions.
166
+
167
+ Args:
168
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
169
+ Actions can be:
170
+ - List of integer actions in DeepMind's format
171
+ - List of string actions in text format (e.g., "A MUN - BER")
172
+
173
+ Returns:
174
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
175
+ Each observation has the same structure as in reset().
176
+ done (bool): Whether the episode has ended.
177
+ info (dict): Additional information about the environment, including:
178
+ - turn: Current turn number
179
+ - returns: Game returns if the game is done, otherwise None
180
+ - waiting_for: List of agents that still need to provide actions (if not all actions are provided)
181
+ """
182
+ # ...
183
+
184
+ def get_log_info(self):
185
+ """Get additional information about the environment for logging.
186
+
187
+ Returns:
188
+ log_info (dict): Information about the environment required to log the game, including:
189
+ - power_names: List of power names
190
+ - game_history: History of the game
191
+ - current_turn: Current turn number
192
+ - current_season: Current season name
193
+ - supply_centers: Dictionary mapping power names to supply center counts
194
+ """
195
+ # ...
196
+
197
+ def render(self):
198
+ """Render the current state of the environment.
199
+
200
+ Displays a visualization of the current game state.
201
+ """
202
+ # ...
203
+
204
+ def close(self):
205
+ """Perform any necessary cleanup."""
206
+ # ...
207
+
208
+
209
+ Key Implementation Details
210
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
211
+
212
+ The ``DiplomacyEnv`` class implements several key features:
213
+
214
+ 1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
215
+
216
+ 2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
217
+
218
+ 3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
219
+
220
+ 4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
221
+
222
+ 5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
223
+
224
+ Observation Structure
225
+ ~~~~~~~~~~~~~~~~~~~~~
226
+
227
+ Each agent receives an observation dictionary with the following structure:
228
+
229
+ .. code-block:: python
230
+
231
+ {
232
+ "board_state": np.ndarray, # Board state representation
233
+ "current_season": int, # Season index (0-4)
234
+ "player_index": int, # Index of the player's power (0-6)
235
+ "possible_actions": [int], # List of possible actions in DeepMind's format
236
+ "human_readable_actions": [str], # List of human-readable action descriptions
237
+ "supply_centers": [str], # List of supply centers owned by the player
238
+ "units": [dict], # List of units owned by the player
239
+ "year": int # Current year in the game
240
+ }
241
+
242
+ Action Structure
243
+ ~~~~~~~~~~~~~~~~
244
+
245
+ Actions can be provided in two formats:
246
+
247
+ 1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
248
+
249
+ 2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
250
+
251
+ The environment will convert text actions to the internal format as needed.
252
+
253
+ DiplomacyAgent
254
+ --------------
255
+
256
+ The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
257
+
258
+ .. code-block:: python
259
+
260
+ class DiplomacyAgent:
261
+ """
262
+ Agent handler for Diplomacy, implementing the AgentState interface
263
+ for the multi-agent negotiation standard.
264
+ """
265
+
266
+ def __init__(self,
267
+ power_name: str,
268
+ use_text_interface: bool = True,
269
+ system_prompt: Optional[str] = None):
270
+ """Initialize the Diplomacy agent handler.
271
+
272
+ Args:
273
+ power_name: Name of the power this agent controls
274
+ use_text_interface: Whether to use text-based interface (vs. structured)
275
+ system_prompt: Optional system prompt to use for the LLM
276
+ """
277
+ # ...
278
+
279
+ def step(self, observation_from_env, policy_output=None):
280
+ """Update the agent state based on the observation and action.
281
+
282
+ Args:
283
+ observation_from_env: The observation from the environment, with structure:
284
+ - board_state: Current state of the board
285
+ - current_season: Current season in the game
286
+ - player_index: Index of the player's power
287
+ - possible_actions: List of possible actions
288
+ - human_readable_actions: List of human-readable action descriptions
289
+ - supply_centers: List of supply centers owned by the player
290
+ - units: List of units owned by the player
291
+ - year: Current year in the game
292
+
293
+ policy_output: The output of the policy (LLM response), or None for initial prompt
294
+
295
+ Returns:
296
+ policy_id (str): The policy identifier ("llm_policy")
297
+ policy_input (dict): The input to the policy, with structure:
298
+ - messages: List of conversation messages in the format:
299
+ [{"role": "system", "content": "..."},
300
+ {"role": "user", "content": "..."}]
301
+ action: The official action to be sent to the environment, or None if not ready
302
+ done (bool): Whether the LLM action is ready to be sent to the environment
303
+ info (dict): Additional information about the agent:
304
+ - valid_action: Whether the extracted action is valid
305
+ """
306
+ # ...
307
+
308
+ def get_log_info(self):
309
+ """Get information about the agent required to log a trajectory.
310
+
311
+ Returns:
312
+ log_info (dict): Information about the agent required to log a trajectory:
313
+ - power_name: Name of the power this agent controls
314
+ - conversation_history: List of conversation messages
315
+ - current_action: The current action, if any
316
+ """
317
+ # ...
318
+
319
+ def render(self):
320
+ """Render the current state of the agent.
321
+
322
+ Displays the agent's current state, including conversation history.
323
+ """
324
+ # ...
325
+
326
+ def close(self):
327
+ """Perform any necessary cleanup."""
328
+ # ...
329
+
330
+
331
+ Key Implementation Details
332
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
333
+
334
+ The ``DiplomacyAgent`` class implements several key features:
335
+
336
+ 1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
337
+
338
+ 2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
339
+
340
+ 3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
341
+
342
+ 4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
343
+
344
+ 5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
345
+
346
+ Prompt Structure
347
+ ~~~~~~~~~~~~~~~~
348
+
349
+ The agent generates prompts that include:
350
+
351
+ 1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
352
+
353
+ 2. **Game State Description**: A text description of the current game state, including:
354
+ - Current year and season
355
+ - Supply centers owned
356
+ - Units controlled
357
+ - Possible actions
358
+
359
+ 3. **Action Request**: Instructions on how to format actions.
360
+
361
+ Example system prompt:
362
+
363
+ .. code-block:: text
364
+
365
+ You are playing the role of FRANCE in a game of Diplomacy.
366
+ Your goal is to control as many supply centers as possible.
367
+ You can negotiate with other players and form alliances, but remember that
368
+ these alliances are not binding. When you need to submit orders for your units,
369
+ write them in the correct format, with each order on a new line.
370
+
371
+ Example game state description:
372
+
373
+ .. code-block:: text
374
+
375
+ Year: 1901, Season: SPRING_MOVES
376
+ You are playing as FRANCE.
377
+ You currently control 3 supply centers: PAR, MAR, BRE.
378
+ Your units are: A PAR, A MAR, F BRE.
379
+
380
+ Please provide orders for your units. Here are your possible actions:
381
+ A PAR - BUR
382
+ A PAR - GAS
383
+ A PAR - PIC
384
+ A PAR H
385
+ ...
386
+
387
+ Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
388
+
389
+ Running Diplomacy Games
390
+ -----------------------
391
+
392
+ To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
393
+
394
+ .. code-block:: python
395
+
396
+ from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
397
+ from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
398
+ from mllm.run_matches import run_batched_matches
399
+
400
+ # Create environment and agent handlers
401
+ env = DiplomacyEnv(max_turns=30)
402
+
403
+ agent_handlers = {
404
+ "AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
405
+ "ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
406
+ "FRANCE": DiplomacyAgent(power_name="FRANCE"),
407
+ "GERMANY": DiplomacyAgent(power_name="GERMANY"),
408
+ "ITALY": DiplomacyAgent(power_name="ITALY"),
409
+ "RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
410
+ "TURKEY": DiplomacyAgent(power_name="TURKEY")
411
+ }
412
+
413
+ # Define policy mapping (mapping from policy IDs to actual policy functions)
414
+ policy_mapping = {
415
+ "llm_policy": my_llm_policy_function
416
+ }
417
+
418
+ # Run the game
419
+ game_results = run_batched_matches(
420
+ envs=[env],
421
+ agent_handlers_per_env=[agent_handlers],
422
+ policy_mapping=policy_mapping,
423
+ max_parallel_matches=1
424
+ )
425
+
426
+ # Process results
427
+ for result in game_results:
428
+ print(f"Game finished. Winner: {result['winner']}")
429
+ print(f"Supply centers: {result['supply_centers']}")
430
+
431
+ This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
432
+
433
+ Limitations and Considerations
434
+ ------------------------------
435
+
436
+ 1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
437
+
438
+ 2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
439
+
440
+ 3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
441
+
442
+ 4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
443
+
444
+ 5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
445
+
446
+ Advanced Usage
447
+ --------------
448
+
449
+ For advanced usage, you can customize:
450
+
451
+ 1. **System Prompts**: Modify agent behavior by providing custom system prompts.
452
+
453
+ 2. **Observation Processing**: Extend the observation processing to include additional information.
454
+
455
+ 3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
456
+
457
+ 4. **Visualization**: Add custom visualization methods to the environment's render function.
458
+
459
+ 5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
src_code_for_reproducibility/docs/source/environments/dond.rst ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Deal or No Deal
3
+ =================
4
+
5
+ The Deal or No Deal (DoND) environment provides a multi-agent negotiation interface where players trade
6
+ items with different values. This document describes the API for interacting with the DoND environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ Deal or No Deal is a negotiation game where two agents must agree on how to divide a set of items,
13
+ each of which has different values to each agent. The agents engage in a back-and-forth dialogue to
14
+ determine an allocation of the items, with each trying to maximize their own total value.
15
+
16
+ Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used
17
+ with LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ Basic Structure
+ ~~~~~~~~~~~~~~~
23
+
24
+ The core mechanics of Deal or No Deal are:
25
+
26
+ 1. Two agents negotiate over a set of items (e.g., books, balls, hats)
27
+ 2. Each item has:
28
+ - A specific quantity (how many of each item are available)
29
+ - A value for each agent (which may differ between agents)
30
+ 3. Agents take turns sending messages to negotiate how to split the items
31
+ 4. Once an agreement is reached, agents finalize the deal
32
+ 5. Points are awarded based on the value of items each agent receives
33
+
34
+ Detailed Gameplay
+ ~~~~~~~~~~~~~~~~~
35
+
36
+ **Setup Phase**
37
+
38
+ The game begins with:
39
+ - A set of items (e.g., "book", "hat", "ball")
40
+ - Each item has a quantity (e.g., 6 books, 2 hats, 4 balls)
41
+ - Each agent has private values for each item (e.g., books might be worth 5 points to one agent but only 2 points to the other)
42
+ - Agents are assigned roles (starting negotiator and responding negotiator)
43
+
44
+ **Negotiation Phase**
45
+
46
+ 1. Agents take turns sending free-form text messages to each other
47
+ 2. Messages can include offers, counter-offers, questions, or strategic communication
48
+ 3. There is a maximum number of messages permitted (preventing endless negotiations)
49
+ 4. Either agent can propose to finalize an agreement at any time
50
+
51
+ For example:
52
+ - Agent 1: "I propose I get all the books and you get all the hats and balls."
53
+ - Agent 2: "That doesn't work for me. How about you get 3 books and I get 3 books, all the hats, and all the balls?"
54
+ - Agent 1: "Let me counter-offer: I get 4 books and 2 balls, you get 2 books, all hats, and 2 balls."
55
+
56
+ **Finalization Phase**
57
+
58
+ 1. When an agent wants to finalize a deal, they must specify the exact allocation:
59
+ - How many of each item they receive
60
+ - How many of each item the other agent receives
61
+ 2. The other agent must then either agree (by submitting the same allocation) or reject the finalization
62
+ 3. If both agents submit matching finalizations, the deal is executed
63
+ 4. If finalizations don't match, no agreement is reached, and both agents receive 0 points
64
+
65
+ **Scoring**
66
+
67
+ 1. Each agent's score is calculated based on the value of items they receive
68
+ 2. The formula is: Sum(quantity_of_item_i × value_of_item_i_to_agent)
69
+ 3. If no agreement is reached, both agents receive 0 points
70
+
71
+ Example Game
+ ~~~~~~~~~~~~
72
+
73
+ Let's walk through a simple example:
74
+
75
+ **Setup:**
76
+ - Items: Books (4), Hats (2), Balls (6)
77
+ - Agent 1 values: Books=5, Hats=1, Balls=2
78
+ - Agent 2 values: Books=3, Hats=6, Balls=1
79
+
80
+ **Negotiation (simplified):**
81
+ 1. Agent 1: "I would like all the books and balls. You can have the hats."
82
+ 2. Agent 2: "That doesn't work for me. Books are valuable. I propose I get all the hats and 2 books, you get 2 books and all the balls."
83
+ 3. Agent 1: "How about I get 3 books and all the balls, and you get 1 book and all the hats?"
84
+ 4. Agent 2: "I accept your proposal."
85
+
86
+ **Finalization:**
87
+ - Agent 1 submits: Agent 1 gets (Books: 3, Hats: 0, Balls: 6), Agent 2 gets (Books: 1, Hats: 2, Balls: 0)
88
+ - Agent 2 submits the same allocation, confirming agreement
89
+
90
+ **Scoring:**
91
+ - Agent 1 score: (3 books × 5) + (0 hats × 1) + (6 balls × 2) = 15 + 0 + 12 = 27 points
92
+ - Agent 2 score: (1 book × 3) + (2 hats × 6) + (0 balls × 1) = 3 + 12 + 0 = 15 points
93
+
94
+ Game Variations
+ ~~~~~~~~~~~~~~~
95
+
96
+ The DoND environment supports several variations through configuration parameters:
97
+
98
+ **Different Value Distributions**
99
+
100
+ The environment offers multiple ways to assign values to items:
101
+
102
+ 1. **Standard Random Setup (dond_random_setup)**:
103
+ - Items have even-numbered quantities
104
+ - Each agent receives distinct random values for each item
105
+ - Values are drawn from a uniform distribution
106
+
107
+ 2. **Independent Random Values (independent_random_vals)**:
108
+ - Item quantities can be any number in the specified range
109
+ - Values for each agent are drawn independently
110
+ - Creates more varied negotiation scenarios
111
+
112
+ 3. **Bicameral Value Distribution (bicameral_vals_assignator)**:
113
+ - Creates a "high value" and "low value" distribution for each item
114
+ - Each agent values approximately half the items highly and half lowly
115
+ - Values are drawn from normal distributions with different means
116
+ - Creates scenarios with clear trade opportunities
117
+
118
+ #### Visibility Options
119
+
120
+ 1. **Finalization Visibility**:
121
+ - When enabled, both agents can see each other's finalization proposals
122
+ - When disabled, finalization proposals remain private until both are submitted
123
+
124
+ 2. **Other Values Visibility**:
125
+ - When enabled, agents can see each other's value functions
126
+ - When disabled, agents only know their own values
127
+ - Creates information asymmetry and richer negotiation dynamics
128
+
129
+ **Game Modes**
130
+
131
+ 1. **Cooperative Mode ("coop")**:
132
+ - Agents are encouraged to find mutually beneficial solutions
133
+ - Success is measured by the sum of both agents' scores
134
+
135
+ 2. **Competitive Mode ("comp")**:
136
+ - Agents aim to maximize their individual scores
137
+ - Creates more adversarial negotiations
138
+
139
+ **Round Structure**
140
+
141
+ 1. **Single Round**:
142
+ - One negotiation session between the two agents
143
+ - Simple evaluation of negotiation skills
144
+
145
+ 2. **Multiple Rounds**:
146
+ - Agents negotiate multiple times with different item setups
147
+ - Allows for learning and adaptation over time
148
+ - Roles can be swapped between rounds
149
+
150
+ DondEnv
151
+ ------------
152
+
153
+ The ``DondEnv`` class provides an interface to the Deal or No Deal environment that follows the Multi-Agent
154
+ Negotiation Environment standard.
155
+
156
+ .. code-block:: python
157
+
158
+ class DondEnv:
159
+ """
160
+ Multi-Agent Negotiation Environment for Deal or No Deal.
161
+ """
162
+ def __init__(
163
+ self,
164
+ agents,
165
+ mode="coop",
166
+ max_messages=None,
167
+ min_messages=None,
168
+ max_chars_per_message=None,
169
+ rounds_per_game=1,
170
+ random_setup_func=None,
171
+ random_setup_kwargs=None,
172
+ role_assignator_func=None,
173
+ role_assignator_func_kwargs=None,
174
+ finalization_visibility=False,
175
+ other_values_visibility=False,
176
+ random_seed=None
177
+ ):
178
+ """Initialize the Deal or No Deal environment.
179
+
180
+ Args:
181
+ agents: List of agent IDs participating in the game
182
+ mode: Game mode ("coop" or "comp")
183
+ max_messages: Maximum number of messages per agent per round
184
+ min_messages: Minimum number of messages per agent per round
185
+ max_chars_per_message: Maximum characters per message
186
+ rounds_per_game: Number of negotiation rounds to play
187
+ random_setup_func: Function to generate item quantities and values
188
+ random_setup_kwargs: Arguments for the random setup function
189
+ role_assignator_func: Function to assign roles to agents
190
+ role_assignator_func_kwargs: Arguments for the role assignator
191
+ finalization_visibility: Whether agents can see each other's finalizations
192
+ other_values_visibility: Whether agents can see each other's values
193
+ random_seed: Seed for reproducibility
194
+ """
195
+ # ...
196
+
197
+ def reset(self):
198
+ """Reset the environment to an initial state and return the initial observation.
199
+
200
+ Returns:
201
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
202
+ """
203
+ # ...
204
+
205
+ def step(self, actions):
206
+ """Take a step in the environment using the provided actions.
207
+
208
+ Args:
209
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
210
+ Actions can be messages or finalization proposals.
211
+
212
+ Returns:
213
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
214
+ done (bool): Whether the episode has ended.
215
+ info (dict): Additional information about the environment.
216
+ """
217
+ # ...
218
+
219
+ def get_state(self):
220
+ """Retrieve the current state of the game.
221
+
222
+ Returns:
223
+ state (dict): The current state of the game, including items, quantities, values, etc.
224
+ """
225
+ # ...
226
+
227
+ Key Implementation Details
228
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+
230
+ The ``DondEnv`` class implements several key features:
231
+
232
+ 1. **Multi-Agent Support**: The environment tracks two agents and manages their alternating messages.
233
+
234
+ 2. **Turn-Based Dialogue**: The environment enforces turn structure and limits on message count.
235
+
236
+ 3. **Finalization Processing**: The environment validates and processes finalization proposals.
237
+
238
+ 4. **Random Setup**: The environment supports multiple methods of generating negotiation scenarios.
239
+
240
+ 5. **Round Management**: The environment can handle multiple rounds with different setups.
241
+
242
+ Observation Structure
243
+ ~~~~~~~~~~~~~~~~~~~~~
244
+
245
+ Each agent receives an observation (state) dictionary with rich information about the game:
246
+
247
+ .. code-block:: python
248
+
249
+ {
250
+ "mode": str, # Game mode ("coop" or "comp")
251
+ "role_values": dict, # Value mappings for each role
252
+ "role_props": dict, # Properties for each role
253
+ "agent_to_role": dict, # Mapping from agent IDs to roles
254
+ "is_new_round": bool, # Whether this is the start of a new round
255
+ "is_new_game": bool, # Whether this is the start of a new game
256
+ "game_over": bool, # Whether the game is over
257
+ "items": list, # List of item names
258
+ "quantities": dict, # Quantities of each item
259
+ "has_finalized": bool, # Whether finalization has been proposed
260
+ "last_message": dict, # The last message sent
261
+ "messages_remaining": dict, # Number of messages each agent can still send
262
+ # And various history tracking fields
263
+ }
264
+
265
+ Action Structure
266
+ ~~~~~~~~~~~~~~~~
267
+
268
+ Actions can be:
269
+
270
+ 1. **Text Messages**: Free-form text for negotiation.
271
+ 2. **Finalization Proposals**: Structured data specifying the exact allocation of items.
272
+
273
+ Example finalization format:
274
+
275
+ .. code-block:: python
276
+
277
+ {
278
+ "type": "finalize",
279
+ "allocation": {
280
+ "agent1": {"book": 3, "hat": 0, "ball": 6},
281
+ "agent2": {"book": 1, "hat": 2, "ball": 0}
282
+ }
283
+ }
284
+
285
+ Value Setup Functions
286
+ ---------------------
287
+
288
+ The DoND environment provides several functions for setting up item values:
289
+
290
+ .. code-block:: python
291
+
292
+ def dond_random_setup(items, min_quant, max_quant, min_val, max_val, random_seed=None):
293
+ """
294
+ Generates items, even-numbered quantities and distinct random values for each category for both agents.
295
+
296
+ Args:
297
+ items (list): List of items.
298
+ min_quant (int): Minimum quantity per item.
299
+ max_quant (int): Maximum quantity per item.
300
+ min_val (int): Minimum value per item.
301
+ max_val (int): Maximum value per item.
302
+ random_seed (int, optional): Seed for random generation.
303
+
304
+ Returns:
305
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
306
+ """
307
+ # ...
308
+
309
+ def independent_random_vals(items, min_quant, max_quant, min_val, max_val, random_seed=None):
310
+ """
311
+ Generates random quantities and independent random values for both agents.
312
+
313
+ Args:
314
+ Similar to dond_random_setup
315
+
316
+ Returns:
317
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
318
+ """
319
+ # ...
320
+
321
+ def bicameral_vals_assignator(items, min_quant, max_quant, low_val_mean, low_val_std, high_val_mean, high_val_std, random_seed=None):
322
+ """
323
+ Generates values with a bicameral distribution - each agent values half the items highly.
324
+
325
+ Args:
326
+ items (list): List of items.
327
+ min_quant, max_quant: Range for quantities
328
+ low_val_mean, low_val_std: Mean and standard deviation for the "low value" distribution
329
+ high_val_mean, high_val_std: Mean and standard deviation for the "high value" distribution
330
+ random_seed: Seed for reproducibility
331
+
332
+ Returns:
333
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
334
+ """
335
+ # ...
336
+
337
+ Running DoND Games
338
+ ----------------------
339
+
340
+ To run Deal or No Deal games with LLM agents, you can use the following structure:
341
+
342
+ .. code-block:: python
343
+
344
+ from mllm.environments.dond.dond_game import DondEnv
345
+ from mllm.environments.dond.dond_agent import DondAgent
346
+ from mllm.run_matches import run_batched_matches
347
+
348
+ # Create environment
349
+ env = DondEnv(
350
+ agents=["agent1", "agent2"],
351
+ mode="coop",
352
+ max_messages=10,
353
+ rounds_per_game=1,
354
+ random_setup_func="dond_random_setup",
355
+ random_setup_kwargs={
356
+ "items": ["book", "hat", "ball"],
357
+ "min_quant": 2,
358
+ "max_quant": 8,
359
+ "min_val": 1,
360
+ "max_val": 10
361
+ },
362
+ finalization_visibility=False
363
+ )
364
+
365
+ # Create agent handlers (implementation details would vary)
366
+ agent_handlers = {
367
+ "agent1": DondAgent(agent_id="agent1"),
368
+ "agent2": DondAgent(agent_id="agent2")
369
+ }
370
+
371
+ # Define policy mapping
372
+ policy_mapping = {
373
+ "llm_policy": my_llm_policy_function
374
+ }
375
+
376
+ # Run the game
377
+ game_results = run_batched_matches(
378
+ envs=[env],
379
+ agent_handlers_per_env=[agent_handlers],
380
+ policy_mapping=policy_mapping,
381
+ max_parallel_matches=1
382
+ )
383
+
384
+ Limitations and Considerations
385
+ ------------------------------
386
+
387
+ 1. **Negotiation Complexity**: The open-ended nature of negotiations can be challenging for some LLM agents.
388
+
389
+ 2. **Parsing Challenges**: Extracting structured finalization proposals from free-form text requires robust parsing.
390
+
391
+ 3. **Optimization Opportunities**: Different agents may employ different negotiation strategies to optimize outcomes.
392
+
393
+ 4. **Fairness Evaluation**: The environment allows research into questions of fair division and Pareto optimality.
394
+
395
+ 5. **Strategic Deception**: Agents might strategically misrepresent their true values, adding complexity to negotiations.
396
+
397
+ Advanced Usage
398
+ --------------
399
+
400
+ For advanced usage, you can:
401
+
402
+ 1. **Custom Value Functions**: Create more complex distributions of item values for specific research questions.
403
+
404
+ 2. **Novel Negotiation Scenarios**: Design item sets and values to test specific negotiation skills.
405
+
406
+ 3. **Curriculum Learning**: Create progressively more difficult negotiation scenarios.
407
+
408
+ 4. **Communication Analysis**: Analyze the language and strategies used in successful negotiations.
409
+
410
+ 5. **Multi-Round Dynamics**: Study how agents adapt their strategies over multiple rounds.
src_code_for_reproducibility/docs/source/environments/ipd.rst ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ===========================
2
+ Iterated Prisoner's Dilemma
3
+ ===========================
4
+
5
+ The Iterated Prisoner's Dilemma environment provides a classic game theory setting for studying cooperation
6
+ and competition between agents. This document describes the API for interacting with the IPD environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ The Prisoner's Dilemma is a fundamental problem in game theory that demonstrates why two rational individuals might not
13
+ cooperate, even when it appears in their best interest to do so. In the iterated version, the same two players
14
+ repeatedly face the same dilemma, allowing for the development of trust or retaliation based on previous interactions.
15
+
16
+ Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used with
17
+ LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ Basic Premise
+ ~~~~~~~~~~~~~
23
+
24
+ The scenario behind the Prisoner's Dilemma is as follows:
25
+
26
+ Two criminals are arrested and imprisoned. Each prisoner is in solitary confinement with no means of communicating with
27
+ the other. The prosecutors lack sufficient evidence to convict the pair on the principal charge, but they have enough
28
+ to convict both on a lesser charge. Simultaneously, the prosecutors offer each prisoner a bargain:
29
+
30
+ - If both prisoners betray each other, each serves 2 years in prison (the "punishment" payoff)
31
+ - If one betrays the other while the other remains silent, the betrayer goes free (the "temptation" payoff) while the
32
+ silent accomplice serves 3 years (the "sucker" payoff)
33
+ - If both remain silent, each serves only 1 year in prison (the "reward" payoff)
34
+
35
+ Game Mechanics
+ ~~~~~~~~~~~~~~
36
+
37
+ In our implementation, the choices are simplified to:
38
+ - **C**: Cooperate (remain silent)
39
+ - **D**: Defect (betray the other prisoner)
40
+
41
+ Each round, both players simultaneously choose either C or D, and receive points based on the combination of their choices:
42
+
43
+ - Both choose C: Both receive the "reward" payoff (3 points by default)
44
+ - Both choose D: Both receive the "punishment" payoff (1 point by default)
45
+ - One chooses C, one chooses D: The defector receives the "temptation" payoff (5 points by default), while the cooperator
46
+ receives the "sucker" payoff (0 points by default)
47
+
48
+ Example: Single Round
+ ~~~~~~~~~~~~~~~~~~~~~
49
+
50
+ Let's see how a single round plays out:
51
+
52
+ 1. Alice and Bob simultaneously make their choices
53
+ 2. If Alice chooses C and Bob chooses C:
54
+ - Alice receives 3 points
55
+ - Bob receives 3 points
56
+ 3. If Alice chooses C and Bob chooses D:
57
+ - Alice receives 0 points
58
+ - Bob receives 5 points
59
+ 4. If Alice chooses D and Bob chooses C:
60
+ - Alice receives 5 points
61
+ - Bob receives 0 points
62
+ 5. If Alice chooses D and Bob chooses D:
63
+ - Alice receives 1 point
64
+ - Bob receives 1 point
65
+
66
+ Iterated Game Structure
+ ~~~~~~~~~~~~~~~~~~~~~~~
67
+
68
+ The iterated version repeats this basic game for a fixed number of rounds. The key features are:
69
+
70
+ 1. Players know the total number of rounds in advance
71
+ 2. After each round, players learn what choice the other player made
72
+ 3. Players maintain a cumulative score across all rounds
73
+ 4. Players can adjust their strategy based on the history of previous interactions
74
+
75
+ Game Variations
+ ~~~~~~~~~~~~~~~
76
+
77
+ The IPD environment supports several variations through configuration parameters:
78
+
79
+ **Different Payoff Matrices**
80
+
81
+ The standard payoff values can be modified to create different incentive structures:
82
+ - **Traditional PD**: reward=3, punishment=1, temptation=5, sucker=0
83
+ - **Weak Temptation**: reward=3, punishment=1, temptation=4, sucker=0 (reduces the incentive to defect)
84
+ - **Harsh Punishment**: reward=3, punishment=0, temptation=5, sucker=0 (increases the cost of mutual defection)
85
+ - **Generous**: reward=4, punishment=2, temptation=5, sucker=1 (cushions the blow of being betrayed)
86
+
87
+ **Game Length Variations**
88
+
89
+ The number of rounds can significantly impact strategy:
90
+ - **Short Games** (5-10 rounds): Incentivizes more defection, especially near the end
91
+ - **Medium Games** (20-50 rounds): Allows for the development of tit-for-tat and forgiveness strategies
92
+ - **Long Games** (100+ rounds): Favors steady cooperation with occasional "probing" defections
93
+
94
+ Common Strategies
+ ~~~~~~~~~~~~~~~~~
95
+
96
+ While not enforced by the environment, several well-known strategies can emerge:
97
+ - **Always Cooperate**: Always choose C
98
+ - **Always Defect**: Always choose D
99
+ - **Tit for Tat**: Start with C, then copy what the opponent did in the previous round
100
+ - **Forgiving Tit for Tat**: Like Tit for Tat, but occasionally cooperate even after being defected against
101
+ - **Grudger**: Cooperate until the opponent defects once, then always defect
102
+ - **Random**: Choose randomly between C and D
103
+
104
+ IPDEnv
105
+ ------
106
+
107
+ The ``IPDEnv`` class provides an interface to the Iterated Prisoner's Dilemma environment that follows the
108
+ Multi-Agent Negotiation Environment standard.
109
+
110
+ .. code-block:: python
111
+
112
+ class IPDEnv:
113
+ """
114
+ Iterated Prisoner's Dilemma environment following the MarlEnvironment standard.
115
+
116
+ In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
117
+ The payoffs are as follows:
118
+ - If both cooperate: Both receive the "reward" (usually 3 points)
119
+ - If both defect: Both receive the "punishment" (usually 1 point)
120
+ - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
121
+ and the cooperator receives the "sucker" payoff (usually 0 points)
122
+
123
+ The game is played for a specified number of rounds.
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ rounds_per_game: int = 10,
129
+ reward: float = 3.0, # Both cooperate
130
+ punishment: float = 1.0, # Both defect
131
+ temptation: float = 5.0, # Defector's reward when other cooperates
132
+ sucker: float = 0.0, # Cooperator's reward when other defects
133
+ random_seed: Optional[int] = None,
134
+ ):
135
+ """
136
+ Initialize the Iterated Prisoner's Dilemma environment.
137
+
138
+ Args:
139
+ rounds_per_game: Number of rounds to play
140
+ reward: Payoff when both agents cooperate
141
+ punishment: Payoff when both agents defect
142
+ temptation: Payoff for defecting when other agent cooperates
143
+ sucker: Payoff for cooperating when other agent defects
144
+ random_seed: Random seed for reproducibility
145
+ """
146
+ # ...
147
+
148
+ def reset(self) -> Dict[str, Dict[str, Any]]:
149
+ """
150
+ Reset the environment to an initial state and return the initial observation.
151
+
152
+ Returns:
153
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
154
+ """
155
+ # ...
156
+
157
+ def step(self, actions: Dict[str, str]) -> Tuple[Dict[str, Dict[str, Any]], bool, Dict[str, Any]]:
158
+ """
159
+ Take a step in the environment using the provided actions.
160
+
161
+ Args:
162
+ actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
163
+
164
+ Returns:
165
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
166
+ done (bool): Whether the episode has ended.
167
+ info (dict): Additional information about the environment.
168
+ """
169
+ # ...
170
+
171
+ Key Implementation Details
172
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
173
+
174
+ The ``IPDEnv`` class implements several key features:
175
+
176
+ 1. **Two-Agent Support**: The environment tracks two agents ("alice" and "bob") and manages their interactions.
177
+
178
+ 2. **Round-Based Play**: The environment enforces turn structure and tracks game history.
179
+
180
+ 3. **Payoff Matrix**: The environment calculates rewards based on the standard prisoner's dilemma payoff matrix.
181
+
182
+ 4. **Observation Generation**: The environment generates detailed observations for each agent, including action history and rewards.
183
+
184
+ 5. **Game Termination**: The environment tracks game termination after the specified number of rounds.
185
+
186
+ Observation Structure
187
+ ~~~~~~~~~~~~~~~~~~~~~
188
+
189
+ Each agent receives an observation dictionary with the following structure:
190
+
191
+ .. code-block:: python
192
+
193
+ {
194
+ "current_round": int, # Current round number (0-indexed)
195
+ "rounds_per_game": int, # Total number of rounds in the game
196
+ "history": List[Dict], # Complete game history so far
197
+ "last_round_actions": Dict[str, str], # Actions from the previous round (if any)
198
+ "last_round_reward": float, # Reward received in the previous round (if any)
199
+ "total_reward": float, # Cumulative reward so far
200
+ "payoff_matrix": Dict[str, float], # The game's payoff matrix values
201
+ }
202
+
203
+ Action Structure
204
+ ~~~~~~~~~~~~~~~~
205
+
206
+ Actions are simple strings:
207
+
208
+ 1. ``"C"`` for Cooperate
209
+ 2. ``"D"`` for Defect
210
+
211
+ IPDAgent
212
+ --------------
213
+
214
+ The ``IPDAgent`` class implements the agent handler interface for the Iterated Prisoner's Dilemma, processing observations from the environment and generating actions through an LLM.
215
+
216
+ .. code-block:: python
217
+
218
+ class IPDAgent:
219
+ """
220
+ Agent handler for Iterated Prisoner's Dilemma, implementing the AgentState interface
221
+ for the multi-agent negotiation standard.
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ agent_id: str,
227
+ policy_id: str = "llm_policy",
228
+ system_prompt: Optional[str] = None,
229
+ max_errors: int = 3,
230
+ opponent_id: Optional[str] = None,
231
+ ):
232
+ """
233
+ Initialize the IPD agent handler.
234
+
235
+ Args:
236
+ agent_id: Identifier for this agent ("alice" or "bob")
237
+ policy_id: Identifier for the policy this agent uses
238
+ system_prompt: Optional custom system prompt for the LLM
239
+ max_errors: Maximum number of parsing errors before defaulting to cooperate
240
+ opponent_id: Optional identifier of the opponent (inferred if not provided)
241
+ """
242
+ # ...
243
+
244
+ def step(self, observation_from_env: Dict[str, Any], policy_output: str = None) -> Tuple[str, Dict[str, Any], str, bool, Dict[str, Any]]:
245
+ """
246
+ Update the agent state based on the observation and process the policy output.
247
+
248
+ Args:
249
+ observation_from_env: The observation from the environment
250
+ policy_output: The output from the policy (LLM response)
251
+
252
+ Returns:
253
+ policy_id: The policy identifier
254
+ policy_input: The input to the policy
255
+ action: The action to be sent to the environment
256
+ done: Whether the action is ready to be sent to the environment
257
+ info: Additional information about the agent
258
+ """
259
+ # ...
260
+
261
+ Key Implementation Details
262
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
263
+
264
+ The ``IPDAgent`` class implements several key features:
265
+
266
+ 1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses.
267
+
268
+ 2. **Action Extraction**: The agent parses the LLM's output to extract valid actions (C or D).
269
+
270
+ 3. **Error Handling**: The agent provides helpful error messages when parsing fails and defaults to cooperation after multiple failures.
271
+
272
+ 4. **History Tracking**: The agent maintains and provides the complete game history in its prompts.
273
+
274
+ 5. **Strategy Explanation**: The agent can extract and log the reasoning behind an LLM's decisions.
275
+
276
+ Prompt Structure
277
+ ~~~~~~~~~~~~~~~~
278
+
279
+ The agent generates prompts that include:
280
+
281
+ 1. **System Prompt**: Instructions and context for the LLM, explaining its role and the rules of the Prisoner's Dilemma.
282
+
283
+ 2. **Game State Description**: A text description of the current game state, including:
284
+ - Current round number
285
+ - History of previous rounds (if any)
286
+ - Cumulative score
287
+
288
+ 3. **Action Request**: Instructions on how to format the response, requiring an explicit action tag.
289
+
290
+ Example system prompt:
291
+
292
+ .. code-block:: text
293
+
294
+ You are playing as Alice in an Iterated Prisoner's Dilemma game against Bob.
295
+ In each round, you must choose to either Cooperate (C) or Defect (D).
296
+
297
+ The payoffs are:
298
+ - If both players Cooperate: You each get 3 points
299
+ - If both players Defect: You each get 1 point
300
+ - If you Cooperate and Bob Defects: You get 0 points, Bob gets 5 points
301
+ - If you Defect and Bob Cooperates: You get 5 points, Bob gets 0 points
302
+
303
+ Your goal is to maximize your total points across all rounds.
304
+ The game will last for exactly 10 rounds, and both players know this.
305
+
306
+ Example game state prompt:
307
+
308
+ .. code-block:: text
309
+
310
+ Current round: 3/10
311
+
312
+ History:
313
+ Round 1: You chose C, Bob chose C. You earned 3 points.
314
+ Round 2: You chose C, Bob chose D. You earned 0 points.
315
+
316
+ Your total score so far: 3 points
317
+
318
+ What is your choice for round 3?
319
+ Please respond with <action>C</action> to cooperate or <action>D</action> to defect,
320
+ and explain your reasoning.
321
+
322
+ Running IPD Games
323
+ ----------------------
324
+
325
+ To run Iterated Prisoner's Dilemma games with LLM agents, you can use the following code structure:
326
+
327
+ .. code-block:: python
328
+
329
+ from mllm.environments.ipd.ipd_game import IPDEnv
330
+ from mllm.environments.ipd.ipd_agent import IPDAgent
331
+ from mllm.run_matches import run_batched_matches
332
+
333
+ # Create environment
334
+ env = IPDEnv(
335
+ rounds_per_game=10,
336
+ reward=3.0,
337
+ punishment=1.0,
338
+ temptation=5.0,
339
+ sucker=0.0
340
+ )
341
+
342
+ # Create agent handlers
343
+ agent_handlers = {
344
+ "alice": IPDAgent(agent_id="alice"),
345
+ "bob": IPDAgent(agent_id="bob")
346
+ }
347
+
348
+ # Define policy mapping
349
+ policy_mapping = {
350
+ "llm_policy": my_llm_policy_function
351
+ }
352
+
353
+ # Run the game
354
+ game_results = run_batched_matches(
355
+ envs=[env],
356
+ agent_handlers_per_env=[agent_handlers],
357
+ policy_mapping=policy_mapping,
358
+ max_parallel_matches=1
359
+ )
360
+
361
+ # Process results
362
+ for result in game_results:
363
+ print(f"Game finished. Scores: {result['total_rewards']}")
364
+
365
+ Statistics and Analysis
366
+ -----------------------
367
+
368
+ The IPD environment includes utility functions for analyzing game outcomes:
369
+
370
+ 1. **Cooperation Rates**: Percentage of rounds where each agent cooperated.
371
+ 2. **Mutual Cooperation/Defection**: Percentage of rounds where both agents cooperated, and percentage of rounds where both agents defected (reported as two separate rates).
372
+ 3. **Score Distribution**: Analysis of how points were accumulated over the game.
373
+
374
+ These statistics can be calculated using the ``gather_ipd_statistics`` function:
375
+
376
+ .. code-block:: python
377
+
378
+ from mllm.environments.ipd.ipd_statistics_funcs import gather_ipd_statistics
379
+
380
+ stats = gather_ipd_statistics(match_info, env_info)
381
+ print(f"Cooperation rates: {stats['cooperation_rate']}")
382
+ print(f"Mutual cooperation rate: {stats['mutual_cooperation_rate']}")
383
+ print(f"Mutual defection rate: {stats['mutual_defection_rate']}")
384
+
385
+ Limitations and Considerations
386
+ ------------------------------
387
+
388
+ 1. **Determinism**: The game dynamics are deterministic given the agents' actions; the optional random seed only affects initialization.
389
+
390
+ 2. **Limited Player Count**: The IPD environment only supports exactly two players.
391
+
392
+ 3. **Perfect Information**: Both players have perfect information about the game history.
393
+
394
+ 4. **Simultaneous Actions**: Both players act simultaneously, which requires adaptations for some LLM interfaces.
395
+
396
+ 5. **Fixed Game Length**: The total number of rounds is fixed and known to both players from the start.
397
+
398
+ Advanced Usage
399
+ --------------
400
+
401
+ For advanced usage, you can customize:
402
+
403
+ 1. **Payoff Matrix**: Modify reward values to create different incentive structures.
404
+
405
+ 2. **System Prompts**: Customize the LLM's understanding of the game and potential strategies.
406
+
407
+ 3. **Error Handling**: Adjust how the agent responds to invalid LLM outputs.
408
+
409
+ 4. **Analysis**: Create custom statistics gathering for specific research questions.
410
+
411
+ 5. **Integration**: Connect the IPD environment to other negotiation frameworks or tournament systems.
src_code_for_reproducibility/docs/source/index.rst ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Welcome to LLM Negotiation's documentation!
2
+ ===========================================
3
+ This library is a collection of tools for training and evaluating LLM-based agents in multi-agent environments. It is designed to be easy to use and extend.
4
+
5
+ .. toctree::
6
+ :maxdepth: 3
7
+ :caption: Contents:
8
+
9
+ installation
10
+ marl_standard
11
+ environments
12
+ launch
13
+ usage
14
+ modules
15
+ contributing
16
+
17
+ Indices and tables
18
+ ==================
19
+
20
+ * :ref:`genindex`
21
+ * :ref:`modindex`
22
+ * :ref:`search`
src_code_for_reproducibility/docs/source/launch.rst ADDED
File without changes
src_code_for_reproducibility/docs/source/media/runbatch.png ADDED
src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_return\_funcs module
2
+ ================================================
3
+
4
+ .. automodule:: src.environments.dond.dond_return_funcs
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_statistics\_funcs module
2
+ ====================================================
3
+
4
+ .. automodule:: src.environments.dond.dond_statistics_funcs
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.env_imports.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.env\_imports module
2
+ ====================================
3
+
4
+ .. automodule:: src.environments.env_imports
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.hf_agent.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.hf\_agent module
2
+ ===========================
3
+
4
+ .. automodule:: src.models.hf_agent
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.new\_local\_llm module
2
+ =================================
3
+
4
+ .. automodule:: src.models.new_local_llm
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.oai_agent.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.oai\_agent module
2
+ ============================
3
+
4
+ .. automodule:: src.models.oai_agent
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.server_llm.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.server\_llm module
2
+ =============================
3
+
4
+ .. automodule:: src.models.server_llm
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.vllm_worker_wrap.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.vllm\_worker\_wrap module
2
+ ====================================
3
+
4
+ .. automodule:: src.models.vllm_worker_wrap
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.run.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.run module
2
+ ==============
3
+
4
+ .. automodule:: src.run
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.extra_stats.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.extra\_stats module
2
+ =============================
3
+
4
+ .. automodule:: src.utils.extra_stats
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.inherit_args.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.inherit\_args module
2
+ ==============================
3
+
4
+ .. automodule:: src.utils.inherit_args
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.log\_statistics module
2
+ ================================
3
+
4
+ .. automodule:: src.utils.log_statistics
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.model_to_cpu.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.model\_to\_cpu module
2
+ ===============================
3
+
4
+ .. automodule:: src.utils.model_to_cpu
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.quick_stats.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.quick\_stats module
2
+ =============================
3
+
4
+ .. automodule:: src.utils.quick_stats
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/usage.rst ADDED
File without changes
src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc ADDED
Binary file (6.17 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc ADDED
Binary file (1.14 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc ADDED
Binary file (3.9 kB). View file
 
src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging_for_training.py ADDED
File without changes
src_code_for_reproducibility/markov_games/markov_game.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
3
+ In a MarkovGame step,
4
+ 1) each agent takes an action,
5
+ 2) the state transitions with respect to these actions,
6
+ 3) all relevant data of the step is appended to the historical data list
7
+
8
+ In order to perform 3), the agents and the simulation are expected, at each time step,
9
+ to return a log of the state transition (from their perspective).
10
+ For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
11
+ A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
12
+ The approach we use here centralizes the data gathering aspect,
13
+ making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
14
+ only log information for step transitions occurring after the branching out.
15
+ """
16
+ import asyncio
17
+ import copy
18
+ import json
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import Any, List, Literal, Optional, Tuple
22
+
23
+ from transformers.models.idefics2 import Idefics2Config
24
+
25
+ from mllm.markov_games.agent import Agent
26
+ from mllm.markov_games.rollout_tree import AgentActLog, StepLog
27
+ from mllm.markov_games.simulation import Simulation
28
+
29
+ AgentId = str
30
+
31
+
32
@dataclass
class AgentAndActionSafeCopy:
    """Outcome of asking an agent for an action without committing it.

    Holds the action, the log produced alongside it, and a safe copy of the
    agent taken after it acted, so the decision can be applied to a game
    later on.
    """

    # Action returned by the agent's ``act`` call.
    action: Any
    # Log returned by the agent's ``act`` call together with the action.
    action_info: AgentActLog
    # Safe copy of the agent after acting (an agent instance, not a class).
    agent_after_action: Agent
37
+
38
+
39
class MarkovGame(object):
    """Couple a simulation with the agents acting in it and log each step.

    A game step gathers one action per agent, advances the simulation with
    those actions, and bundles the simulation log and the agents' logs into a
    ``StepLog``.
    """

    def __init__(
        self,
        id: int,
        agents: dict[AgentId, Agent],
        simulation: Simulation,
        crn_id: int,
    ):
        """
        Args:
            id:
                Identifier of this game (returned by ``get_id``).
            agents:
                Mapping from agent id to the agent object acting under it.
            simulation:
                Simulation object. Example: IPDSimulation
            crn_id:
                Identifier returned by ``get_crn_id``.
                NOTE(review): presumably a common-random-number id; confirm
                with the callers.
        """
        self.agents = agents
        # Live view on the keys of ``self.agents``.
        self.agent_ids = self.agents.keys()
        self.simulation = simulation
        self.simulation_step_log = None
        self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids}
        self.actions = {}
        self.id = id
        self.crn_id = crn_id

    def get_id(self) -> int:
        """Return the identifier given at construction."""
        return self.id

    def get_crn_id(self) -> int:
        """Return the ``crn_id`` given at construction."""
        return self.crn_id

    def get_agent_ids(self) -> List[AgentId]:
        """Return the agent ids as a new list."""
        return list(self.agent_ids)

    async def get_action_of_agent_without_side_effects(
        self, agent_id: AgentId
    ) -> AgentAndActionSafeCopy:
        """
        Get an action of an agent while leaving the game's agent untouched.

        The agent acts on its current observation; afterwards the game keeps
        a copy of the agent taken *before* acting, and the returned bundle
        carries a copy taken *after* acting.
        """
        agent = self.agents[agent_id]
        agent_before_action = agent.get_safe_copy()
        obs = self.simulation.get_obs_agent(agent_id)
        try:
            action, action_info = await agent.act(observation=obs)
        finally:
            # Restore the pre-action agent even when ``act`` fails, so the
            # "no side effects" promise also holds on the error path.
            self.agents[agent_id] = agent_before_action
        agent_after_action = agent.get_safe_copy()
        return AgentAndActionSafeCopy(action, action_info, agent_after_action)

    async def get_actions_of_agents_without_side_effects(
        self,
    ) -> dict[AgentId, AgentAndActionSafeCopy]:
        """
        Concurrently get an action of every agent without modifying the
        game's agents. Returns one bundle per agent id.
        """
        # Snapshot the ids so results are paired with a fixed ordering.
        agent_ids = list(self.agent_ids)
        agent_and_action_safe_copies = await asyncio.gather(
            *(
                self.get_action_of_agent_without_side_effects(agent_id)
                for agent_id in agent_ids
            )
        )
        return dict(zip(agent_ids, agent_and_action_safe_copies))

    def set_action_and_agent_after_action_manually(
        self,
        agent_id: AgentId,
        agent_action_safe_copy: AgentAndActionSafeCopy,
    ):
        """
        Commit a previously computed decision: store its action and log for
        ``agent_id`` and replace the game's agent with the post-action copy.
        """
        self.actions[agent_id] = agent_action_safe_copy.action
        self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info
        self.agents[agent_id] = agent_action_safe_copy.agent_after_action

    def set_actions_of_agents_manually(
        self, actions: dict[AgentId, AgentAndActionSafeCopy]
    ):
        """
        Commit previously computed decisions for several agents at once.
        """
        for agent_id, agent_action_safe_copy in actions.items():
            self.set_action_and_agent_after_action_manually(
                agent_id, agent_action_safe_copy
            )

    async def set_action_of_agent(self, agent_id: AgentId):
        """
        Let the agent act on its current observation and store the resulting
        action and log. The agent object itself is used, not a copy.
        """
        agent = self.agents[agent_id]
        obs = self.simulation.get_obs_agent(agent_id)
        action, action_info = await agent.act(observation=obs)
        self.actions[agent_id] = action
        self.agent_step_logs[agent_id] = action_info

    async def set_actions(self):
        """
        Concurrently let every agent act and store all actions and logs.
        """
        await asyncio.gather(
            *(self.set_action_of_agent(agent_id) for agent_id in list(self.agent_ids))
        )

    def take_simulation_step(self):
        """
        Advance the simulation with the stored actions, keep the simulation's
        step log, and return the simulation's ``terminated`` flag.
        """
        terminated, self.simulation_step_log = self.simulation.step(self.actions)
        return terminated

    def get_step_log(self) -> StepLog:
        """
        Bundle the latest simulation log and agent logs into a ``StepLog``.

        TODO: assert actions and simulation have taken step
        """
        return StepLog(
            simulation_step_log=self.simulation_step_log,
            action_logs=self.agent_step_logs,
        )

    async def step(self) -> Tuple[bool, StepLog]:
        """
        Run one full game step: collect actions, advance the simulation, and
        return ``(terminated, step_log)``.
        """
        await self.set_actions()
        terminated = self.take_simulation_step()
        step_log = self.get_step_log()
        return terminated, step_log

    def get_safe_copy(self):
        """
        Return a copy of the game whose simulation, agents, stored actions,
        and logs are independent of this game's.
        """
        new_markov_game = copy.copy(self)

        # Reassign copied components
        new_markov_game.simulation = self.simulation.get_safe_copy()
        new_markov_game.agents = {
            agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items()
        }

        # IMPORTANT: ensure agent_ids references the new agents dict, not the original
        new_markov_game.agent_ids = new_markov_game.agents.keys()

        # Deep-copy step data to avoid correlation
        new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log)
        new_markov_game.actions = copy.deepcopy(self.actions)
        # Rebuild logs to align exactly with new agent ids
        old_agent_step_logs = copy.deepcopy(self.agent_step_logs)
        new_markov_game.agent_step_logs = {
            agent_id: old_agent_step_logs.get(agent_id)
            for agent_id in new_markov_game.agent_ids
        }

        return new_markov_game
src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc ADDED
Binary file (9.06 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Optional
3
+ from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
4
+ from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
5
+ from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressObs
6
+ from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
7
+ from mllm.markov_games.negotiation.nego_simulation import Split
8
+ from typing import Any, Tuple
9
+
10
class HardCodedNegoWelfareMaximizingPolicy(NoPressAgent):
    """Scripted agent that always proposes a welfare-maximizing split."""

    async def act(self, observation: NoPressObs) -> Tuple[Any, AgentActLog]:
        """
        Propose the split that gives each item to whoever values it more.

        For every item in ``observation.quantities`` this agent claims:
          * the full quantity when its own value is strictly higher,
          * nothing when the other agent's value is strictly higher,
          * half of the quantity when both values are equal.

        Values missing from ``observation.value`` / ``observation.other_value``
        are treated as 0.

        Returns:
            ``(Split, AgentActLog)`` where the log holds a single assistant
            chat turn describing the policy.
        """
        quantities = observation.quantities
        my_values = observation.value
        other_values = observation.other_value

        items_given_to_self = {}
        for item, qty in quantities.items():
            my_v = float(my_values.get(item, 0))
            other_v = float(other_values.get(item, 0))
            if my_v == other_v:
                # Tie: claim exactly half. True division is used, so the
                # claimed amount is a float and may be fractional.
                items_given_to_self[item] = int(qty) / 2
            else:
                items_given_to_self[item] = int(qty if my_v > other_v else 0)

        action = Split(items_given_to_self=items_given_to_self)
        act_log = AgentActLog(
            chat_turns=[
                ChatTurn(
                    agent_id=self.agent_id,
                    role="assistant",
                    content="Using welfare-maximizing split (all to higher-value agent).",
                    is_state_end=True,
                )
            ],
            info=None,
        )
        return action, act_log
42
+
43
class HardCodedNegoGreedyPolicy(NoPressAgent):
    """Scripted agent that claims the whole pool for itself."""

    async def act(self, observation: NoPressObs) -> Tuple[Any, AgentActLog]:
        """
        Propose a split in which this agent keeps the full quantity of every
        item listed in ``observation.quantities``.
        """
        items_given_to_self = {}
        for item, qty in observation.quantities.items():
            items_given_to_self[item] = int(qty)

        # Single assistant turn recording which scripted policy was applied.
        turn = ChatTurn(
            agent_id=self.agent_id,
            role="assistant",
            content="Using greedy split (keep all items).",
            is_state_end=True,
        )
        return (
            Split(items_given_to_self=items_given_to_self),
            AgentActLog(chat_turns=[turn], info=None),
        )
64
+
src_code_for_reproducibility/markov_games/negotiation/negotiation_statistics.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable, Dict, List, Tuple
4
+
5
+ from mllm.markov_games.negotiation.nego_simulation import Split
6
+ from mllm.markov_games.rollout_tree import SimulationStepLog
7
+
8
+
9
def avg_reward(sl: SimulationStepLog) -> List[Tuple[str, float]] | None:
    """Per-agent reward at the current simulation step.

    What it computes:
        One ``("reward-<agent_id>", reward)`` pair for every agent present in
        ``sl.rewards``, with the reward converted to ``float``. No aggregate
        across agents is added.

    Rationale / motivation:
        Monitoring the reward stream at each step helps:
          * Diagnose reward shaping issues (e.g., unintended negative drift).
          * Provide a fairness snapshot (are rewards systematically skewed?).
          * Supply a baseline metric for higher-level summaries.

    Returns:
        A list of ``(key, value)`` pairs (empty when there are no rewards), or
        ``None`` when any agent id contains "buffer" but not "live"; such
        steps are skipped so they do not pollute aggregates.
    """
    # Treat a missing reward mapping like an empty one.
    rewards = sl.rewards or {}
    if any("buffer" in str(aid) and "live" not in str(aid) for aid in rewards):
        return None
    # One value per agent at each step
    return [(f"reward-{aid}", float(value)) for aid, value in rewards.items()]
37
+
38
+
39
def split_efficiency(sl: SimulationStepLog) -> List[Tuple[str, float]] | None:
    """Final-round allocation efficiency relative to an upper bound.

    What it computes (only on the last timestep of a negotiation round):
        - Uses ``info['values']`` (per-agent valuations, either a dict per
          item or a single number applying to every item) and
          ``info['quantities']`` (available item counts) to form an *upper
          bound* on total reward: every unit of an item goes to the agent
          who values that item most.
        - Compares the realized sum of rewards at that timestep to the bound.
        - Emits a single pair ``("split_efficiency", achieved / bound)``.

    Motivation:
        Efficiency distinguishes coordination failures (low efficiency) from
        distributional disputes (high efficiency but uneven splits).

    Returns:
        A one-element list, or ``None`` when the metric is undefined:
        not the last timestep of a round, missing ``values`` / ``quantities``
        / rewards, a non-live buffer agent is present, or the bound is zero.

    Notes / caveats:
        This is an optimistic bound (not necessarily reachable under protocol
        constraints) but is simple, fast, and comparable across runs.
    """
    info = sl.info or {}
    if not info.get("is_last_timestep_in_round"):
        return None
    quantities = info.get("quantities") or {}
    values = info.get("values") or {}
    if not values or not quantities:
        return None
    rewards = sl.rewards or {}
    if not rewards:
        return None
    # Skip buffer steps before doing any work on them.
    if any("buffer" in str(aid) and "live" not in str(aid) for aid in rewards):
        return None

    sample_values = next(iter(values.values()))
    if isinstance(sample_values, dict):
        # Per-item valuations: each item goes to its highest bidder.
        max_reward = sum(
            quantities[item]
            * max(float(agent_vals[item]) for agent_vals in values.values())
            for item in sample_values
        )
    else:
        # One valuation per agent that applies to every item, so the bound is
        # the best valuation times the total quantity of all items.
        best_value = max(float(v) for v in values.values())
        max_reward = best_value * sum(quantities.values())

    if max_reward == 0:
        # Ratio undefined when nothing of value can be allocated.
        return None
    achieved = sum(float(v) for v in rewards.values())
    return [("split_efficiency", achieved / max_reward)]
92
+
93
+
94
+ def _extract_items_from_split(raw_split: Dict) -> Dict[str, float] | None:
95
+ """Return a mapping item->proposal amount from a split structure.
96
+
97
+ Supports both generic negotiation splits with nested structure
98
+ { 'items_given_to_self': {item: qty, ...}}
99
+ and TAS coin-only variants which may already be a flat mapping {'coins': qty}.
100
+ """
101
+
102
+ if raw_split is None:
103
+ return {}
104
+ elif isinstance(raw_split, Split):
105
+ return {k: float(v) for k, v in raw_split.items_given_to_self.items()}
106
+ elif isinstance(raw_split, dict):
107
+ if "items_given_to_self" in raw_split and isinstance(
108
+ raw_split["items_given_to_self"], dict
109
+ ):
110
+ return {k: float(v) for k, v in raw_split["items_given_to_self"].items()}
111
+ # Fallback: assume already flat mapping of items
112
+ elif hasattr(raw_split, "items_given_to_self"):
113
+ return {k: float(v) for k, v in raw_split["items_given_to_self"].items()}
114
+ return {
115
+ k: float(v) for k, v in raw_split.items() if isinstance(v, (int, float))
116
+ }
117
+ return {}
118
+
119
+
120
def _average_proposal_relative_value(
    sl: SimulationStepLog,
    metric_name: str,
    comparator: Callable[[float, float], bool],
    opposite_comparator: Callable[[float, float], bool],
) -> List[Tuple[str, float]] | None:
    """Shared implementation for proposal size conditioned on relative value.

    Parameters:
        sl: step log; ``sl.info`` supplies ``quantities``, ``splits`` and
            ``values``, ``sl.rewards`` supplies the two agent ids.
        metric_name: prefix of the emitted keys
            (``"<metric_name>-<agent_id>"``).
        comparator: called as ``comparator(v0, v1)`` with the first and
            second agent's value of an item; when it returns True the first
            agent's proposed quantity for that item is collected.
        opposite_comparator: inverse relation; when it returns True (and
            ``comparator`` did not) the second agent's proposed quantity is
            collected.

    Behavior:
        - Executes only on the final timestep of a round.
        - Averages (mean) the collected quantities per agent. An agent with
          no qualifying item gets no entry in the result.
        - Items missing from an agent's split count as 0.0.

    Returns:
        A list of ``(key, mean)`` pairs, or ``None`` when the step is not the
        last of a round, there are not exactly two agents, a non-live buffer
        agent is present, or an agent has no entry in ``values``.

    Why this matters:
        How much an agent asks for when it values items more (or less) than
        its counterpart reveals opportunism vs. concession even when raw
        reward differences are subtle.
    """
    info = sl.info or {}
    if not info.get("is_last_timestep_in_round"):
        return None
    quantities = info.get("quantities") or {}
    splits = info.get("splits") or {}
    values = info.get("values") or {}
    agent_ids: List[str] = list((sl.rewards or {}).keys())
    if len(agent_ids) != 2:
        return None  # Only defined for 2-agent case.
    if any("buffer" in str(aid) and "live" not in str(aid) for aid in agent_ids):
        return None
    first_id, second_id = agent_ids
    # Without valuations for both agents the items cannot be classified.
    if quantities and (first_id not in values or second_id not in values):
        return None

    # Extract per-agent item proposals robustly
    first_split = _extract_items_from_split(splits.get(first_id))
    second_split = _extract_items_from_split(splits.get(second_id))

    # Values may be either a float (same for all items) or dict per item
    v0_raw = values.get(first_id)
    v1_raw = values.get(second_id)

    first_amounts: List[float] = []
    second_amounts: List[float] = []
    for item in quantities:
        v0 = float(v0_raw[item]) if isinstance(v0_raw, dict) else float(v0_raw)
        v1 = float(v1_raw[item]) if isinstance(v1_raw, dict) else float(v1_raw)
        if comparator(v0, v1):
            first_amounts.append(first_split.get(item, 0.0))
        elif opposite_comparator(v0, v1):
            second_amounts.append(second_split.get(item, 0.0))

    out: List[Tuple[str, float]] = []
    if first_amounts:
        out.append((f"{metric_name}-{first_id}", sum(first_amounts) / len(first_amounts)))
    if second_amounts:
        out.append((f"{metric_name}-{second_id}", sum(second_amounts) / len(second_amounts)))
    return out
184
+
185
+
186
def average_proposal_when_agent_values_item_lower(
    sl: SimulationStepLog,
) -> List[Tuple[str, float]] | None:
    """Mean quantity an agent proposes for items it values *less* than opponent.

    Interpretation:
        A higher value implies the agent still claims a notable share of
        items where it has a comparative *disadvantage* in valuation,
        signaling either strategic over-claiming or egalitarian splits.
        Very low numbers can indicate efficient specialization or excessive
        concession.

    Returns:
        Whatever ``_average_proposal_relative_value`` returns for this
        metric: a list of
        ``("average_proposal_when_agent_values_item_lower-<agent_id>", mean)``
        pairs (agents without qualifying items are omitted), or ``None``
        when the metric is not defined for this step.
    """
    return _average_proposal_relative_value(
        sl,
        "average_proposal_when_agent_values_item_lower",
        lambda a, b: a < b,
        lambda a, b: a > b,
    )
208
+
209
+
210
def average_proposal_when_agent_values_item_higher(
    sl: SimulationStepLog,
) -> List[Tuple[str, float]] | None:
    """Mean quantity an agent proposes for items it values *more* than opponent.

    Interpretation:
        Captures how aggressively an agent claims items where it holds a
        comparative *advantage*. Elevated values can reflect rational
        specialization or potentially unfair grabs if paired with low
        concession in the lower-valuation metric. Comparing this with the
        'lower' counterpart helps profile negotiation style.

    Returns:
        Whatever ``_average_proposal_relative_value`` returns for this
        metric: a list of
        ``("average_proposal_when_agent_values_item_higher-<agent_id>", mean)``
        pairs (agents without qualifying items are omitted), or ``None``
        when the metric is not defined for this step.
    """
    return _average_proposal_relative_value(
        sl,
        "average_proposal_when_agent_values_item_higher",
        lambda a, b: a > b,
        lambda a, b: a < b,
    )
233
+
234
+
235
# Explicit list of metric functions exported for rendering. Helper functions
# starting with '_' are intentionally excluded. Update this list when adding
# new public statistics so render.py can rely on it instead of introspecting
# every callable in the module.
# Each function may return None to signal "no value for this step".
stat_functs: list[Callable[[SimulationStepLog], List[Tuple[str, float]] | None]] = [
    avg_reward,
    average_proposal_when_agent_values_item_lower,
    average_proposal_when_agent_values_item_higher,
    split_efficiency,
]
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc ADDED
Binary file (3.21 kB). View file
 
src_code_for_reproducibility/training/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (155 Bytes). View file
 
src_code_for_reproducibility/training/__pycache__/produce_training_stats.cpython-312.pyc ADDED
Binary file (15.4 kB). View file
 
src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc ADDED
Binary file (3.09 kB). View file
 
src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
src_code_for_reproducibility/training/__pycache__/tokenize_chats.cpython-312.pyc ADDED
Binary file (5.85 kB). View file
 
src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc ADDED
Binary file (19.5 kB). View file
 
src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc ADDED
Binary file (39.7 kB). View file
 
src_code_for_reproducibility/training/tally_tokenwise.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class ContextualizedTokenwiseTally:
12
+ """
13
+ Collect, store, and save token-level metrics per rollout.
14
+
15
+ - One DataFrame per rollout_id in `paths`
16
+ - Index = timestep (int)
17
+ - Columns are added incrementally via `add_contexts()` and `add_data()`
18
+ - Cells may contain scalars, strings, or lists (dtype=object)
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ tokenizer: AutoTokenizer,
24
+ paths: List[str],
25
+ max_context_length: int = 30,
26
+ ):
27
+ """
28
+ Args:
29
+ tokenizer: HuggingFace tokenizer used to convert tids -> tokens
30
+ paths: rollout identifiers (parallel to batch dimension)
31
+ max_context_length: truncate context token lists to this length
32
+ """
33
+ self.tokenizer = tokenizer
34
+ self.paths = paths
35
+ self.max_context_length = max_context_length
36
+ self.tally: Dict[str, pd.DataFrame] = {path: pd.DataFrame() for path in paths}
37
+
38
+ # set later by setters
39
+ self.contexts: torch.Tensor | None = None
40
+ self.action_mask: torch.Tensor | None = None
41
+ self.range: Tuple[int, int] | None = None
42
+
43
+ # --------- Utilities ---------
44
+
45
+ def tids_to_str(self, tids: List[int]) -> List[str]:
46
+ """Convert a list of token IDs to a list of token strings."""
47
+ return self.tokenizer.convert_ids_to_tokens(tids)
48
+
49
+ def _ensure_ready(self):
50
+ assert self.action_mask is not None, "call set_action_mask(mask) first"
51
+ assert self.range is not None, "call set_range((start, end)) first"
52
+
53
+ @staticmethod
54
+ def _sanitize_filename(name: Any) -> str:
55
+ """Make a safe filename from any rollout_id."""
56
+ s = str(name)
57
+ bad = {os.sep, " ", ":", "|", "<", ">", '"', "'"}
58
+ if os.altsep is not None:
59
+ bad.add(os.altsep)
60
+ for ch in bad:
61
+ s = s.replace(ch, "_")
62
+ return s
63
+
64
+ @staticmethod
65
+ def _pad_left(seq: List[Any], length: int, pad_val: Any = "") -> List[Any]:
66
+ """Left-pad a sequence to `length` with `pad_val`."""
67
+ if len(seq) >= length:
68
+ return seq[-length:]
69
+ return [pad_val] * (length - len(seq)) + list(seq)
70
+
71
+ # --------- Setters ---------
72
+
73
+ def set_action_mask(self, action_mask: torch.Tensor):
74
+ """
75
+ action_mask: (B, S) bool or 0/1 indicating valid steps
76
+ """
77
+ self.action_mask = action_mask
78
+
79
+ def set_range(self, range: Tuple[int, int]):
80
+ """
81
+ range: slice (start, end) into self.paths for current batch
82
+ """
83
+ self.range = range
84
+
85
+ # --------- Column builders ---------
86
+
87
def add_contexts(self, contexts: torch.Tensor):
    """
    Fill the 'context' column (list[str]) for every valid step.

    `contexts` must have shape (B, S) and hold one token id per timestep.
    Row i is attributed to the i-th rollout of the current range. For each
    timestep t flagged by the action mask, the ids
        contexts[i, max(0, t - N + 1) : t + 1]
    are converted with `self.tids_to_str`, where N is
    `self.max_context_length`; the resulting list is left-padded with ""
    so that it always has N entries.
    """
    self._ensure_ready()

    lo, hi = self.range[0], self.range[1]
    batch_rollouts = self.paths[lo:hi]
    n_rows, _ = contexts.shape
    width = self.max_context_length

    # Bring the ids to host memory once instead of once per window.
    ids_cpu = contexts.detach().to("cpu")

    for row in range(n_rows):
        rollout_id = batch_rollouts[row]
        frame = self.tally.get(rollout_id, pd.DataFrame())

        steps = torch.nonzero(
            self.action_mask[row].bool(), as_tuple=False
        ).squeeze(-1)
        if steps.numel() == 0:
            # No valid step: still register the rollout in the tally.
            self.tally[rollout_id] = frame
            continue
        step_list = steps.tolist()

        # Every valid step needs a row in the frame.
        if frame.empty:
            frame = pd.DataFrame(index=step_list)
        else:
            merged_index = sorted(set(frame.index.tolist()) | set(step_list))
            if list(frame.index) != merged_index:
                frame = frame.reindex(merged_index)

        # One token window per valid step, each exactly `width` long.
        windows = []
        for t in step_list:
            first = max(0, t - width + 1)
            raw_ids = ids_cpu[row, first : t + 1].tolist()
            tokens = self.tids_to_str([int(tid) for tid in raw_ids])
            missing = width - len(tokens)
            if missing > 0:
                tokens = [""] * missing + tokens
            else:
                tokens = tokens[-width:]
            windows.append(tokens)

        # Object dtype lets each cell hold a whole list of tokens.
        if "context" not in frame.columns:
            frame["context"] = pd.Series(index=frame.index, dtype=object)
        frame.loc[step_list, "context"] = pd.Series(
            windows, index=step_list, dtype=object
        )

        self.tally[rollout_id] = frame
146
+
147
def add_data(
    self,
    metric_id: str,
    metrics: torch.Tensor,
    to_tids: bool = False,
):
    """
    Fill one metric column for every valid step.

    Args:
        metric_id: name of the column to write
        metrics: (B, S) tensor of scalars/ids, or (B, S, K) tensor holding a
            length-K vector per step (e.g. top-k)
        to_tids: when True, values are interpreted as token ids and
            converted with `self.tids_to_str` (a vector becomes a list of
            tokens, a scalar becomes a single token)

    Raises:
        ValueError: if `metrics` is neither 2-D nor 3-D, or if the number of
            extracted values does not match the number of valid steps.
    """
    self._ensure_ready()
    batch_rollouts = self.paths[self.range[0] : self.range[1]]

    if metrics.dim() not in (2, 3):
        raise ValueError("metrics must be (B, S) or (B, S, K)")
    n_rows = metrics.shape[0]

    def as_tokens(entry):
        # A list is a per-step vector of ids; anything else is a single id.
        if isinstance(entry, list):
            return self.tids_to_str([int(tid) for tid in entry])
        return self.tids_to_str([int(entry)])[0]

    for row in range(n_rows):
        rollout_id = batch_rollouts[row]
        frame = self.tally.get(rollout_id, pd.DataFrame())

        steps = torch.nonzero(
            self.action_mask[row].bool(), as_tuple=False
        ).squeeze(-1)
        if steps.numel() == 0:
            # No valid step: still register the rollout in the tally.
            self.tally[rollout_id] = frame
            continue
        step_list = steps.detach().cpu().tolist()

        # Every valid step needs a row in the frame.
        if frame.empty:
            frame = pd.DataFrame(index=step_list)
        else:
            merged_index = sorted(set(frame.index.tolist()) | set(step_list))
            if list(frame.index) != merged_index:
                frame = frame.reindex(merged_index)

        # Values at the valid steps as plain Python (list or list of lists).
        values = metrics[row][steps].detach().cpu().tolist()
        if to_tids:
            values = [as_tokens(v) for v in values]

        # Object dtype lets a cell hold a scalar, a token or a whole list.
        if metric_id not in frame.columns:
            frame[metric_id] = pd.Series(index=frame.index, dtype=object)

        if len(values) != len(step_list):
            raise ValueError(
                f"Length mismatch for '{metric_id}': values={len(values)} vs idx_list={len(step_list)}"
            )

        frame.loc[step_list, metric_id] = pd.Series(
            values, index=step_list, dtype=object
        )
        self.tally[rollout_id] = frame
225
+
226
+ # --------- Saving ---------
227
+
228
def save(self, path: str):
    """
    Write one CSV per rollout plus a JSON manifest into directory `path`.

    - Nothing is written when the tally is empty or holds only empty frames.
    - Each CSV is named `<sanitized rollout id>_tokenwise.csv` and stores
      the frame index under the label 'timestep'.
    - The 'context' column, when a frame has one, is written first; frames
      without it are written with their existing columns (previously they
      were silently dropped because selecting the missing column failed).
    - Empty frames are skipped.
    - Writing is best-effort: a rollout whose CSV cannot be written is
      logged and left out of the manifest instead of aborting the save.
    - The manifest holds JSON-safe metadata only. Its 'num_rollouts' field
      counts every tally entry, including ones that were skipped.

    Args:
        path: output directory; created if it does not exist.
    """
    import logging  # local import: the file's top-level imports are not assumed

    if not self.tally or all(df.empty for df in self.tally.values()):
        return

    os.makedirs(path, exist_ok=True)
    from datetime import datetime

    now = datetime.now()
    log = logging.getLogger(__name__)

    manifest = {
        "created_at": f"{now:%Y-%m-%d %H:%M:%S}",
        "max_context_length": self.max_context_length,
        "num_rollouts": len(self.tally),
        "rollouts": [],
    }

    for rid, df in self.tally.items():
        if df.empty:
            # Nothing was recorded for this rollout.
            continue

        rid_str = str(rid)
        safe_name = self._sanitize_filename(rid_str)
        csv_path = os.path.join(path, f"{safe_name}_tokenwise.csv")

        # Put 'context' first when present, then the rest in existing order.
        leading = ["context"] if "context" in df.columns else []
        cols = leading + [c for c in df.columns if c != "context"]
        try:
            df[cols].to_csv(csv_path, index=True, index_label="timestep")
        except Exception:
            # Keep saving the other rollouts, but leave a trace of the failure.
            log.exception(
                "Failed to write tokenwise CSV for rollout %r to %s",
                rid_str,
                csv_path,
            )
            continue

        manifest["rollouts"].append(
            {
                "rollout_id": rid_str,
                "csv": csv_path,
                "num_rows": int(df.shape[0]),
                "columns": cols,
            }
        )

    manifest_path = os.path.join(
        path, f"tokenwise_manifest_{now:%Y-%m-%d___%H-%M-%S}.json"
    )
    with open(manifest_path, "w", encoding="utf-8") as fp:
        json.dump(manifest, fp, indent=2)
src_code_for_reproducibility/training/tokenize_chats.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+
4
+ import regex
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+
8
+ from mllm.training.training_data_utils import TrainingChatTurn, TrajectoryBatch
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.addHandler(logging.StreamHandler(sys.stdout))
12
+
13
+
14
+ # def get_chat_dicts(chat: list[TrainingChatTurn]) -> list[dict]:
15
+ # chat_dicts = [chat_turn.dict() for chat_turn in chat]
16
+ # return chat_dicts
17
+
18
+
19
+ def process_training_chat(
20
+ tokenizer: AutoTokenizer,
21
+ chat_history: list[TrainingChatTurn],
22
+ entropy_mask_regex: str | None = None,
23
+ exploration_prompts_to_remove: list[str] = [],
24
+ use_engine_out_token_ids: bool = False,
25
+ ) -> tuple[torch.IntTensor, torch.BoolTensor, torch.IntTensor, torch.BoolTensor]:
26
+ """Tokenize a single training chat and build aligned per-token masks.
27
+
28
+ Given an ordered list of `TrainingChatTurn`, this function tokenizes each
29
+ turn independently using the tokenizer's chat template, then concatenates
30
+ all resulting token sequences. It also constructs three parallel 1D masks
31
+ that align with the concatenated tokens:
32
+
33
+ - input_ids: token ids for the entire chat, turn by turn
34
+ - action_mask: True for tokens that belong to assistant turns (i.e., model
35
+ actions), False for tokens from other roles
36
+ - timesteps: per-token time step copied from the originating turn's
37
+ `time_step`
38
+ - state_ends_mask: True for the last token of any turn where
39
+ `is_state_end` is True, otherwise False
40
+
41
+ Important details:
42
+ - Each turn is passed as a single-message list to
43
+ `tokenizer.apply_chat_template` and flattened; the per-turn outputs are
44
+ then concatenated in the original order.
45
+ - Turn boundaries are not explicitly encoded beyond what the chat template
46
+ inserts; masks provide alignment for learning signals and state endings.
47
+ - No truncation or padding is performed here; downstream code should handle
48
+ batching/padding as needed.
49
+ - Note on dtypes: `input_ids` will be a LongTensor (int64). `action_mask`
50
+ and `state_ends_mask` are BoolTensors. `timesteps` is currently created
51
+ as a float tensor; adjust the implementation if integer dtype is
52
+ required downstream.
53
+
54
+ Args:
55
+ tokenizer: A Hugging Face tokenizer supporting `apply_chat_template`.
56
+ chat_history: Ordered list of `TrainingChatTurn` forming one dialogue.
57
+
58
+ Returns:
59
+ A tuple of four 1D tensors, all of equal length N (the total number of
60
+ tokens across all turns), in the following order:
61
+ - input_ids (LongTensor)
62
+ - action_mask (BoolTensor)
63
+ - timesteps (FloatTensor as implemented; see note above)
64
+ - state_ends_mask (BoolTensor)
65
+ """
66
+ state_ends_mask = []
67
+ input_ids = []
68
+ action_mask = []
69
+ timesteps = []
70
+ entropy_mask = []
71
+ engine_log_probs = []
72
+ for train_chat_turn in chat_history:
73
+ is_state_end = train_chat_turn.is_state_end
74
+ time_step = train_chat_turn.time_step
75
+ is_action = train_chat_turn.role == "assistant"
76
+
77
+ # Remove exploration prompts from training data
78
+ for exploration_prompt in exploration_prompts_to_remove:
79
+ if exploration_prompt in train_chat_turn.content:
80
+ train_chat_turn.content = train_chat_turn.content.replace(
81
+ exploration_prompt, ""
82
+ )
83
+
84
+ chat_turn = {
85
+ "role": train_chat_turn.role,
86
+ "content": train_chat_turn.content,
87
+ }
88
+ if entropy_mask_regex is not None:
89
+ is_entropy_mask_true = (
90
+ regex.search(entropy_mask_regex, train_chat_turn.content) is not None
91
+ )
92
+ else:
93
+ is_entropy_mask_true = True
94
+ if is_action:
95
+ chat_turn_ids = train_chat_turn.out_token_ids
96
+ nb_chat_turns_ids = chat_turn_ids.numel()
97
+ action_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
98
+ engine_log_probs.append(train_chat_turn.log_probs)
99
+ else:
100
+ chat_turn_ids = train_chat_turn.chat_template_token_ids
101
+ nb_chat_turns_ids = chat_turn_ids.numel()
102
+ action_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
103
+ engine_log_probs.append(torch.zeros(nb_chat_turns_ids, dtype=torch.float))
104
+ nb_chat_turns_ids = chat_turn_ids.numel()
105
+ state_ends_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
106
+ if is_state_end:
107
+ state_ends_mask[-1][-1] = True # last token is state end
108
+ input_ids.append(chat_turn_ids)
109
+ entropy_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
110
+ if not is_entropy_mask_true:
111
+ entropy_mask[-1] = entropy_mask[-1] * False
112
+ timesteps.append(torch.ones(nb_chat_turns_ids) * time_step)
113
+ input_ids = torch.cat(input_ids)
114
+ action_mask = torch.cat(action_mask)
115
+ entropy_mask = torch.cat(entropy_mask)
116
+ timesteps = torch.cat(timesteps)
117
+ timesteps = timesteps.to(torch.long)
118
+ state_ends_mask = torch.cat(state_ends_mask)
119
+ engine_log_probs = torch.cat(engine_log_probs)
120
+
121
+ return (
122
+ input_ids,
123
+ action_mask,
124
+ entropy_mask,
125
+ timesteps,
126
+ state_ends_mask,
127
+ engine_log_probs,
128
+ )
src_code_for_reproducibility/training/trainer_sum_rewards.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ """
4
+ import logging
5
+ import os
6
+ import sys
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from accelerate import Accelerator
12
+ from pandas._libs.tslibs.offsets import CBMonthBegin
13
+ from peft import LoraConfig
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+
17
+ from mllm.markov_games.rollout_tree import *
18
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode
19
+ from mllm.training.credit_methods import (
20
+ get_discounted_returns,
21
+ get_discounted_state_visitation_credits,
22
+ get_generalized_advantage_estimates,
23
+ get_rloo_credits,
24
+ )
25
+ from mllm.training.tally_metrics import Tally
26
+ from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
27
+ from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
28
+ from mllm.training.tokenize_chats import *
29
+ from mllm.training.tokenize_chats import process_training_chat
30
+ from mllm.training.trainer_common import BaseTrainer
31
+ from mllm.training.trainer_independent import TrainerNaive, TrainingData
32
+ from mllm.training.training_data_utils import *
33
+ from mllm.training.training_data_utils import (
34
+ AdvantagePacket,
35
+ TrainingBatch,
36
+ TrajectoryBatch,
37
+ get_tokenwise_credits,
38
+ )
39
+ from mllm.utils.resource_context import resource_logger_context
40
+
41
+ logger = logging.getLogger(__name__)
42
+ logger.addHandler(logging.StreamHandler(sys.stdout))
43
+
44
+
45
class TrainerSumRewards(TrainerNaive):
    """Trainer whose credits are its own advantages plus its co-agent's."""

    def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
        """
        Add the co-agent's advantages to this trainer's own advantages.

        For every agent in `self.training_data`, the advantages of the first
        packet from another agent that contains the same rollout id are
        looked up, summed step by step with the agent's own advantages,
        optionally passed through
        `get_discounted_state_visitation_credits`, and stored as
        `batch_credits` on the agent's main data, which then replaces the
        agent's entry in `self.training_data`.

        Args:
            advantage_packets: packets from all trainers; packets whose
                `agent_id` equals the current agent are ignored.
                `rollout_ids` of a packet is presumably a tensor (it is
                compared with `torch.where`) — TODO confirm.

        Raises:
            AssertionError: if no packet is given, if some rollout has no
                matching co-agent packet, or if matched advantage sequences
                differ in length.
        """
        logger.info("Receiving advantage packets.")

        assert (
            len(advantage_packets) > 0
        ), "At least one advantage packet must be provided."

        for agent_id, agent_data in self.training_data.items():
            coagent_advantage_packets = [
                packet for packet in advantage_packets if packet.agent_id != agent_id
            ]
            agent_rollout_ids = agent_data.main_data.rollout_ids
            agent_advantages = agent_data.main_advantages

            # Collect, in rollout order, the co-agent advantages of the same rollout.
            co_agent_advantages = []
            for rollout_id in agent_rollout_ids:
                for co_agent_packet in coagent_advantage_packets:
                    if rollout_id in co_agent_packet.rollout_ids:
                        index = torch.where(
                            rollout_id == co_agent_packet.rollout_ids
                        )[0].item()
                        co_agent_advantages.append(
                            co_agent_packet.main_advantages[index]
                        )
                        # Assumes a two-player game: only the first matching
                        # co-agent packet is used.
                        break
            assert len(co_agent_advantages) == len(
                agent_advantages
            ), "Every rollout must have matching co-agent advantages."
            B = len(agent_advantages)
            assert all(
                a.shape[0] == b.shape[0]
                for a, b in zip(co_agent_advantages, agent_advantages)
            ), "Number of advantages must match in order to sum them up."

            # Original (unpadded) lengths, needed to undo the padding below.
            lengths = [len(t) for t in agent_advantages]

            # Zero padding is neutral for the element-wise sum.
            padded_main_advantages = pad_sequence(
                agent_advantages, batch_first=True, padding_value=0.0
            )
            padded_co_agent_advantages = pad_sequence(
                co_agent_advantages, batch_first=True, padding_value=0.0
            )

            sum_of_ad_credits = padded_main_advantages + padded_co_agent_advantages
            self.rollout_tally.add_metric(
                path=["sum_of_ad_credits"],
                rollout_tally_item=RolloutTallyItem(
                    crn_ids=agent_data.main_data.crn_ids,
                    rollout_ids=agent_data.main_data.rollout_ids,
                    agent_ids=agent_data.main_data.agent_ids,
                    metric_matrix=sum_of_ad_credits,
                ),
            )

            if not self.skip_discounted_state_visitation:
                sum_of_ad_credits = get_discounted_state_visitation_credits(
                    sum_of_ad_credits,
                    self.discount_factor,
                )
                # Log the credits just computed; the previous code read an
                # undefined `sub_tensors` here and raised NameError.
                self.rollout_tally.add_metric(
                    path=["discounted_state_visitation_credits"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=agent_data.main_data.crn_ids,
                        rollout_ids=agent_data.main_data.rollout_ids,
                        agent_ids=agent_data.main_data.agent_ids,
                        metric_matrix=sum_of_ad_credits,
                    ),
                )

            # Slice back to jagged (per-rollout) tensors.
            sum_of_ad_credits = [sum_of_ad_credits[i, : lengths[i]] for i in range(B)]
            self.training_data[agent_id] = agent_data.main_data
            self.training_data[agent_id].batch_credits = sum_of_ad_credits
src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc ADDED
Binary file (422 Bytes). View file
 
src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc ADDED
Binary file (4.55 kB). View file
 
src_code_for_reproducibility/utils/get_coagent_id.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ def get_coagent_id(ids: list[str], agent_id:str) -> str | None:
3
+ for id in ids:
4
+ if id != agent_id: return id