jtowarek committed
Commit 3ff9218 · verified · 1 Parent(s): 9ab079e

Upload folder using huggingface_hub
Dockerfile CHANGED
@@ -2,7 +2,7 @@ FROM python:3.11-slim
 
 WORKDIR /app
 
-RUN pip install --no-cache-dir gradio pydantic
+RUN pip install --no-cache-dir gradio pydantic anthropic openai
 
 COPY . /app
 
train/Dockerfile ADDED
@@ -0,0 +1,20 @@
+FROM nvcr.io/nvidia/pytorch:24.08-py3
+
+WORKDIR /workspace
+
+# Install dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy training script
+COPY train.py .
+
+# Default: train with Qwen2.5-7B-Instruct, 500 steps
+CMD ["python", "train.py", \
+     "--model", "Qwen/Qwen2.5-7B-Instruct", \
+     "--episodes", "2000", \
+     "--max-steps", "500", \
+     "--num-generations", "8", \
+     "--batch-size", "2", \
+     "--grad-accum", "8", \
+     "--output-dir", "/workspace/output"]
train/__init__.py ADDED
@@ -0,0 +1,42 @@
+"""Training pipeline for strategic reasoning via game-theory environments."""
+
+__all__ = [
+    "LLMAgent",
+    "PromptBuilder",
+    "parse_action",
+    "episode_reward",
+    "get_train_eval_split",
+    "EpisodeTrajectory",
+    "StepRecord",
+    "TrajectoryCollector",
+]
+
+
+def __getattr__(name: str) -> object:
+    """Lazy imports to avoid pulling in openenv at package load time."""
+    if name in ("LLMAgent", "PromptBuilder", "parse_action"):
+        from train.agent import LLMAgent, PromptBuilder, parse_action
+        _map = {
+            "LLMAgent": LLMAgent,
+            "PromptBuilder": PromptBuilder,
+            "parse_action": parse_action,
+        }
+        return _map[name]
+    if name == "episode_reward":
+        from train.rewards import episode_reward
+        return episode_reward
+    if name == "get_train_eval_split":
+        from train.splits import get_train_eval_split
+        return get_train_eval_split
+    if name in ("EpisodeTrajectory", "StepRecord", "TrajectoryCollector"):
+        from train.trajectory import (
+            EpisodeTrajectory, StepRecord, TrajectoryCollector,
+        )
+        _map = {
+            "EpisodeTrajectory": EpisodeTrajectory,
+            "StepRecord": StepRecord,
+            "TrajectoryCollector": TrajectoryCollector,
+        }
+        return _map[name]
+    msg = f"module 'train' has no attribute {name!r}"
+    raise AttributeError(msg)
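A minimal usage sketch (not part of the commit) of the PEP 562 lazy-import pattern above; it assumes the repository's env and constant_definitions packages are importable so the submodules can load on first access:

    import train

    agent_cls = train.LLMAgent        # first access imports train.agent lazily
    reward_fn = train.episode_reward  # first access imports train.rewards lazily

    # Unknown names raise AttributeError with the message built above.
    try:
        train.does_not_exist
    except AttributeError as err:
        print(err)  # module 'train' has no attribute 'does_not_exist'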
train/agent.py ADDED
@@ -0,0 +1,185 @@
1
+ """LLM agent for game-theory environments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from typing import Any, Callable, Dict, List, Optional
7
+
8
+ from env.models import GameAction, GameObservation
9
+ from constant_definitions.train.agent_constants import (
10
+ MAX_ACTION_TOKENS,
11
+ MAX_PROMPT_HISTORY_ROUNDS,
12
+ PARSE_FAILURE_SENTINEL,
13
+ PROMPT_SECTION_ACTIONS,
14
+ PROMPT_SECTION_GAME,
15
+ PROMPT_SECTION_HISTORY,
16
+ PROMPT_SECTION_INSTRUCTION,
17
+ PROMPT_SECTION_SCORES,
18
+ SYSTEM_PROMPT,
19
+ TRAIN_TEMPERATURE_DENOMINATOR,
20
+ TRAIN_TEMPERATURE_NUMERATOR,
21
+ )
22
+
23
+ _ZERO = int()
24
+ _ONE = int(bool(True))
25
+ _NEWLINE = "\n"
26
+ _SECTION_SEP = "\n\n"
27
+ _BRACKET_OPEN = "["
28
+ _BRACKET_CLOSE = "]"
29
+ _COLON_SPACE = ": "
30
+ _DASH_SPACE = "- "
31
+ _ROUND_PREFIX = "Round "
32
+ _YOU_PLAYED = " | You played: "
33
+ _OPP_PLAYED = " | Opponent played: "
34
+ _YOUR_PAYOFF = " | Your payoff: "
35
+ _OPP_PAYOFF = " | Opp payoff: "
36
+
37
+
38
+ class PromptBuilder:
39
+ """Formats GameObservation into a structured text prompt.
40
+
41
+ The prompt intentionally excludes the opponent strategy name
42
+ to prevent the model from shortcutting via strategy recognition.
43
+ """
44
+
45
+ @staticmethod
46
+ def build(obs: GameObservation) -> str:
47
+ """Build a structured prompt from a game observation."""
48
+ sections: List[str] = []
49
+
50
+ # Game section
51
+ sections.append(
52
+ _BRACKET_OPEN + PROMPT_SECTION_GAME + _BRACKET_CLOSE
53
+ + _NEWLINE + obs.game_name
54
+ + _NEWLINE + obs.game_description
55
+ )
56
+
57
+ # History section (limited to last N rounds)
58
+ if obs.history:
59
+ history_lines: List[str] = []
60
+ history_slice = obs.history[-MAX_PROMPT_HISTORY_ROUNDS:]
61
+ for rnd in history_slice:
62
+ line = (
63
+ _ROUND_PREFIX + str(rnd.round_number)
64
+ + _YOU_PLAYED + rnd.player_action
65
+ + _OPP_PLAYED + rnd.opponent_action
66
+ + _YOUR_PAYOFF + str(rnd.player_payoff)
67
+ + _OPP_PAYOFF + str(rnd.opponent_payoff)
68
+ )
69
+ history_lines.append(line)
70
+ sections.append(
71
+ _BRACKET_OPEN + PROMPT_SECTION_HISTORY + _BRACKET_CLOSE
72
+ + _NEWLINE + _NEWLINE.join(history_lines)
73
+ )
74
+
75
+ # Scores section
76
+ sections.append(
77
+ _BRACKET_OPEN + PROMPT_SECTION_SCORES + _BRACKET_CLOSE
78
+ + _NEWLINE + "Your score" + _COLON_SPACE + str(obs.player_score)
79
+ + _NEWLINE + "Opponent score" + _COLON_SPACE + str(obs.opponent_score)
80
+ + _NEWLINE + "Round" + _COLON_SPACE + str(obs.current_round)
81
+ + " of " + str(obs.total_rounds)
82
+ )
83
+
84
+ # Available actions
85
+ action_lines = [_DASH_SPACE + a for a in obs.available_actions]
86
+ sections.append(
87
+ _BRACKET_OPEN + PROMPT_SECTION_ACTIONS + _BRACKET_CLOSE
88
+ + _NEWLINE + _NEWLINE.join(action_lines)
89
+ )
90
+
91
+ # Instruction
92
+ sections.append(
93
+ _BRACKET_OPEN + PROMPT_SECTION_INSTRUCTION + _BRACKET_CLOSE
94
+ + _NEWLINE + SYSTEM_PROMPT
95
+ )
96
+
97
+ return _SECTION_SEP.join(sections)
98
+
99
+
100
+ def parse_action(response: str, available_actions: List[str]) -> str:
101
+ """Parse an action from LLM response text.
102
+
103
+ Tries: exact match -> case-insensitive -> substring -> random selection.
104
+ """
105
+ stripped = response.strip()
106
+
107
+ # Exact match
108
+ if stripped in available_actions:
109
+ return stripped
110
+
111
+ # Case-insensitive match
112
+ lower = stripped.lower()
113
+ for action in available_actions:
114
+ if action.lower() == lower:
115
+ return action
116
+
117
+ # Substring match (response contains action name)
118
+ for action in available_actions:
119
+ if action.lower() in lower:
120
+ return action
121
+
122
+ # Random selection as last resort
123
+ return random.choice(available_actions)
124
+
125
+
126
+ class LLMAgent:
127
+ """LLM-based agent compatible with TournamentRunner agent_fn interface.
128
+
129
+ Parameters
130
+ ----------
131
+ generate_fn : callable
132
+ A function that takes a prompt string and returns a completion string.
133
+ This abstracts over different model backends (HF, vLLM, API).
134
+ prompt_builder : PromptBuilder, optional
135
+ Custom prompt builder. Defaults to the standard PromptBuilder.
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ generate_fn: Callable[[str], str],
141
+ prompt_builder: Optional[PromptBuilder] = None,
142
+ ) -> None:
143
+ self._generate_fn = generate_fn
144
+ self._prompt_builder = prompt_builder or PromptBuilder()
145
+ self._last_prompt: str = ""
146
+ self._last_completion: str = ""
147
+
148
+ def __call__(self, obs: GameObservation) -> GameAction:
149
+ """Select an action given a game observation."""
150
+ prompt = self._prompt_builder.build(obs)
151
+ self._last_prompt = prompt
152
+ completion = self._generate_fn(prompt)
153
+ self._last_completion = completion
154
+ action_str = parse_action(completion, obs.available_actions)
155
+ return GameAction(action=action_str)
156
+
157
+ @property
158
+ def last_prompt(self) -> str:
159
+ """The most recently constructed prompt."""
160
+ return self._last_prompt
161
+
162
+ @property
163
+ def last_completion(self) -> str:
164
+ """The most recent raw model completion."""
165
+ return self._last_completion
166
+
167
+
168
+ class APIAgent(LLMAgent):
169
+ """Agent that uses an external API (OpenAI/Anthropic) for generation.
170
+
171
+ Parameters
172
+ ----------
173
+ api_call_fn : callable
174
+ Function(system_prompt, user_prompt) -> str that calls the API.
175
+ """
176
+
177
+ def __init__(
178
+ self,
179
+ api_call_fn: Callable[[str, str], str],
180
+ prompt_builder: Optional[PromptBuilder] = None,
181
+ ) -> None:
182
+ def _generate(prompt: str) -> str:
183
+ return api_call_fn(SYSTEM_PROMPT, prompt)
184
+
185
+ super().__init__(generate_fn=_generate, prompt_builder=prompt_builder)
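A short illustrative sketch (not part of the commit) of the parse_action fallback chain and of wiring LLMAgent to an arbitrary backend; the stub generate_fn below is hypothetical, and the import assumes the repository's env package is on the path:

    from train.agent import LLMAgent, parse_action

    actions = ["cooperate", "defect"]
    print(parse_action("cooperate", actions))           # exact match
    print(parse_action("DEFECT", actions))              # case-insensitive match
    print(parse_action("I will cooperate.", actions))   # substring match
    # Anything unparseable falls back to random.choice(actions).

    # Any callable str -> str works as a backend (HF pipeline, vLLM, an API client, ...).
    agent = LLMAgent(generate_fn=lambda prompt: "cooperate")
    # Calling agent(obs) builds the prompt, runs generate_fn, and returns a GameAction.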
train/dpo/__init__.py ADDED
@@ -0,0 +1,7 @@
+"""DPO (Direct Preference Optimisation) training subpackage."""
+
+from train.dpo.config import DPOConfig
+from train.dpo.pairs import generate_preference_pairs
+from train.dpo.trainer import KantDPOTrainer
+
+__all__ = ["DPOConfig", "generate_preference_pairs", "KantDPOTrainer"]
train/dpo/config.py ADDED
@@ -0,0 +1,82 @@
1
+ """DPO training configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from constant_definitions.train.dpo_constants import (
8
+ DPO_BATCH_SIZE,
9
+ DPO_BETA_DENOMINATOR,
10
+ DPO_BETA_NUMERATOR,
11
+ DPO_GRADIENT_ACCUMULATION_STEPS,
12
+ DPO_LR_DENOMINATOR,
13
+ DPO_LR_NUMERATOR,
14
+ DPO_MAX_LENGTH,
15
+ DPO_MIN_REWARD_MARGIN_DENOMINATOR,
16
+ DPO_MIN_REWARD_MARGIN_NUMERATOR,
17
+ DPO_NUM_EPOCHS,
18
+ DPO_TRAJECTORIES_PER_PAIR,
19
+ DPO_WARMUP_RATIO_DENOMINATOR,
20
+ DPO_WARMUP_RATIO_NUMERATOR,
21
+ )
22
+
23
+
24
+ @dataclass(frozen=True)
25
+ class DPOConfig:
26
+ """Configuration for DPO training."""
27
+
28
+ # Core hyperparameters
29
+ beta_numerator: int = DPO_BETA_NUMERATOR
30
+ beta_denominator: int = DPO_BETA_DENOMINATOR
31
+ learning_rate_numerator: int = DPO_LR_NUMERATOR
32
+ learning_rate_denominator: int = DPO_LR_DENOMINATOR
33
+ batch_size: int = DPO_BATCH_SIZE
34
+ num_epochs: int = DPO_NUM_EPOCHS
35
+ max_length: int = DPO_MAX_LENGTH
36
+ gradient_accumulation_steps: int = DPO_GRADIENT_ACCUMULATION_STEPS
37
+
38
+ # Warmup
39
+ warmup_ratio_numerator: int = DPO_WARMUP_RATIO_NUMERATOR
40
+ warmup_ratio_denominator: int = DPO_WARMUP_RATIO_DENOMINATOR
41
+
42
+ # Pair generation
43
+ trajectories_per_pair: int = DPO_TRAJECTORIES_PER_PAIR
44
+ min_reward_margin_numerator: int = DPO_MIN_REWARD_MARGIN_NUMERATOR
45
+ min_reward_margin_denominator: int = DPO_MIN_REWARD_MARGIN_DENOMINATOR
46
+
47
+ # Model
48
+ model_name: str = ""
49
+ output_dir: str = "checkpoints/dpo"
50
+
51
+ @property
52
+ def beta(self) -> float:
53
+ """Effective beta (KL penalty coefficient)."""
54
+ return self.beta_numerator / self.beta_denominator
55
+
56
+ @property
57
+ def learning_rate(self) -> float:
58
+ """Effective learning rate."""
59
+ return self.learning_rate_numerator / self.learning_rate_denominator
60
+
61
+ @property
62
+ def warmup_ratio(self) -> float:
63
+ """Effective warmup ratio."""
64
+ return self.warmup_ratio_numerator / self.warmup_ratio_denominator
65
+
66
+ @property
67
+ def min_reward_margin(self) -> float:
68
+ """Minimum reward margin for preference pair filtering."""
69
+ return self.min_reward_margin_numerator / self.min_reward_margin_denominator
70
+
71
+ def to_trl_kwargs(self) -> dict:
72
+ """Return keyword arguments suitable for TRL DPOConfig."""
73
+ return {
74
+ "beta": self.beta,
75
+ "learning_rate": self.learning_rate,
76
+ "per_device_train_batch_size": self.batch_size,
77
+ "num_train_epochs": self.num_epochs,
78
+ "max_length": self.max_length,
79
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
80
+ "warmup_ratio": self.warmup_ratio,
81
+ "output_dir": self.output_dir,
82
+ }
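A small sketch of the numerator/denominator pattern used above; the override values below are made up for illustration (the real defaults come from constant_definitions.train.dpo_constants):

    from train.dpo.config import DPOConfig

    cfg = DPOConfig(
        beta_numerator=1, beta_denominator=10,                       # beta = 0.1
        learning_rate_numerator=5, learning_rate_denominator=10**6,  # lr = 5e-6
    )
    print(cfg.beta, cfg.learning_rate)
    print(cfg.to_trl_kwargs()["beta"])  # same derived value, ready for TRL's DPOConfig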
train/dpo/pairs.py ADDED
@@ -0,0 +1,108 @@
1
+ """Preference pair generation for DPO training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ from train.trajectory import EpisodeTrajectory
8
+ from constant_definitions.game_constants import EVAL_ONE, EVAL_ZERO
9
+ from constant_definitions.train.dpo_constants import (
10
+ DPO_BOTTOM_QUANTILE_DENOMINATOR,
11
+ DPO_BOTTOM_QUANTILE_NUMERATOR,
12
+ DPO_MIN_REWARD_MARGIN_DENOMINATOR,
13
+ DPO_MIN_REWARD_MARGIN_NUMERATOR,
14
+ DPO_TOP_QUANTILE_DENOMINATOR,
15
+ DPO_TOP_QUANTILE_NUMERATOR,
16
+ )
17
+
18
+ _ONE = int(bool(True))
19
+
20
+
21
+ def generate_preference_pairs(
22
+ trajectories: List[EpisodeTrajectory],
23
+ min_margin_numerator: int = DPO_MIN_REWARD_MARGIN_NUMERATOR,
24
+ min_margin_denominator: int = DPO_MIN_REWARD_MARGIN_DENOMINATOR,
25
+ ) -> List[Dict[str, Any]]:
26
+ """Generate chosen/rejected preference pairs from trajectories.
27
+
28
+ Groups trajectories by (game, strategy), ranks by episode_reward,
29
+ pairs top-quartile (chosen) vs bottom-quartile (rejected), and
30
+ filters by minimum reward margin.
31
+
32
+ Returns list of dicts with keys: prompt, chosen, rejected, margin.
33
+ """
34
+ min_margin = min_margin_numerator / min_margin_denominator
35
+
36
+ # Group by (game, strategy)
37
+ groups: Dict[Tuple[str, str], List[EpisodeTrajectory]] = {}
38
+ for traj in trajectories:
39
+ key = (traj.game, traj.strategy)
40
+ if key not in groups:
41
+ groups[key] = []
42
+ groups[key].append(traj)
43
+
44
+ pairs: List[Dict[str, Any]] = []
45
+ for _key, group in groups.items():
46
+ group_pairs = _pairs_from_group(group, min_margin)
47
+ pairs.extend(group_pairs)
48
+
49
+ return pairs
50
+
51
+
52
+ def _pairs_from_group(
53
+ group: List[EpisodeTrajectory],
54
+ min_margin: float,
55
+ ) -> List[Dict[str, Any]]:
56
+ """Generate pairs from a single (game, strategy) group."""
57
+ if len(group) < EVAL_ONE + EVAL_ONE:
58
+ return []
59
+
60
+ # Sort by episode reward descending
61
+ ranked = sorted(group, key=lambda t: t.episode_reward, reverse=True)
62
+ n = len(ranked)
63
+
64
+ # Top and bottom quartile boundaries
65
+ top_boundary = max(
66
+ _ONE,
67
+ (n * DPO_TOP_QUANTILE_NUMERATOR) // DPO_TOP_QUANTILE_DENOMINATOR,
68
+ )
69
+ bottom_boundary = max(
70
+ _ONE,
71
+ (n * DPO_BOTTOM_QUANTILE_NUMERATOR) // DPO_BOTTOM_QUANTILE_DENOMINATOR,
72
+ )
73
+
74
+ chosen_set = ranked[:top_boundary]
75
+ rejected_set = ranked[n - bottom_boundary:]
76
+
77
+ pairs: List[Dict[str, Any]] = []
78
+ for chosen in chosen_set:
79
+ for rejected in rejected_set:
80
+ margin = chosen.episode_reward - rejected.episode_reward
81
+ if margin < min_margin:
82
+ continue
83
+ # Use the full episode as prompt + chosen/rejected completions
84
+ chosen_text = _trajectory_to_text(chosen)
85
+ rejected_text = _trajectory_to_text(rejected)
86
+ prompt = _trajectory_prompt(chosen)
87
+ pairs.append({
88
+ "prompt": prompt,
89
+ "chosen": chosen_text,
90
+ "rejected": rejected_text,
91
+ "margin": margin,
92
+ "game": chosen.game,
93
+ "strategy": chosen.strategy,
94
+ })
95
+
96
+ return pairs
97
+
98
+
99
+ def _trajectory_to_text(traj: EpisodeTrajectory) -> str:
100
+ """Convert trajectory actions to a single completion string."""
101
+ return "\n".join(step.completion for step in traj.steps)
102
+
103
+
104
+ def _trajectory_prompt(traj: EpisodeTrajectory) -> str:
105
+ """Extract the first step's prompt as the shared prompt."""
106
+ if traj.steps:
107
+ return traj.steps[EVAL_ZERO].prompt
108
+ return ""
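For reference, the shape of a single pair emitted by generate_preference_pairs (keys as produced above; all values illustrative only):

    pair = {
        "prompt":   "[GAME]\nPrisoner's Dilemma\n...",  # first step's prompt of the chosen episode
        "chosen":   "cooperate\ncooperate\n...",        # newline-joined completions, top-quartile episode
        "rejected": "defect\ndefect\n...",              # newline-joined completions, bottom-quartile episode
        "margin":   1.7,                                # chosen reward minus rejected reward
        "game":     "prisoners_dilemma",
        "strategy": "tit_for_tat",
    }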
train/dpo/trainer.py ADDED
@@ -0,0 +1,162 @@
1
+ """DPO trainer wrapping TRL with Kant-specific preference learning."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, Dict, List, Optional, Sequence
7
+
8
+ from env.environment import KantEnvironment
9
+ from env.models import GameAction, GameObservation
10
+ from train.agent import LLMAgent, PromptBuilder, parse_action
11
+ from train.dpo.config import DPOConfig
12
+ from train.dpo.pairs import generate_preference_pairs
13
+ from train.splits import get_train_eval_split
14
+ from train.trajectory import EpisodeTrajectory
15
+
16
+ from constant_definitions.game_constants import EVAL_ZERO
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class KantDPOTrainer:
22
+ """DPO trainer for strategic reasoning via preference learning.
23
+
24
+ Wraps TRL's DPOTrainer with:
25
+ - Preference pair generation from trajectory rankings
26
+ - Per-checkpoint evaluation on held-out games
27
+ - Optional LoRA/QLoRA support via PEFT
28
+
29
+ Parameters
30
+ ----------
31
+ config : DPOConfig
32
+ Training configuration.
33
+ model : Any
34
+ HuggingFace model (or path to load).
35
+ tokenizer : Any
36
+ HuggingFace tokenizer.
37
+ ref_model : Any, optional
38
+ Reference model for DPO. If None, uses a copy of the policy model.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ config: DPOConfig,
44
+ model: Any = None,
45
+ tokenizer: Any = None,
46
+ ref_model: Any = None,
47
+ ) -> None:
48
+ self._config = config
49
+ self._model = model
50
+ self._tokenizer = tokenizer
51
+ self._ref_model = ref_model
52
+ self._train_games, self._eval_games = get_train_eval_split()
53
+ self._trl_trainer: Any = None
54
+
55
+ def prepare_dataset(
56
+ self,
57
+ trajectories: List[EpisodeTrajectory],
58
+ ) -> List[Dict[str, Any]]:
59
+ """Generate preference pairs from collected trajectories."""
60
+ return generate_preference_pairs(
61
+ trajectories,
62
+ min_margin_numerator=self._config.min_reward_margin_numerator,
63
+ min_margin_denominator=self._config.min_reward_margin_denominator,
64
+ )
65
+
66
+ def setup_trl_trainer(
67
+ self,
68
+ train_dataset: Any,
69
+ ) -> Any:
70
+ """Initialise the TRL DPOTrainer (requires trl to be installed)."""
71
+ try:
72
+ from trl import DPOTrainer, DPOConfig as TRLDPOConfig
73
+ except ImportError as exc:
74
+ msg = "trl is required for DPO training. Install with: pip install trl"
75
+ raise ImportError(msg) from exc
76
+
77
+ trl_config = TRLDPOConfig(**self._config.to_trl_kwargs())
78
+ self._trl_trainer = DPOTrainer(
79
+ model=self._model,
80
+ ref_model=self._ref_model,
81
+ args=trl_config,
82
+ tokenizer=self._tokenizer,
83
+ train_dataset=train_dataset,
84
+ )
85
+ return self._trl_trainer
86
+
87
+ def evaluate(
88
+ self,
89
+ games: Optional[Sequence[str]] = None,
90
+ strategies: Optional[Sequence[str]] = None,
91
+ run_external: bool = False,
92
+ external_benchmarks: Optional[Sequence[str]] = None,
93
+ ) -> Dict[str, float]:
94
+ """Run evaluation on specified games and return metric dict.
95
+
96
+ Parameters
97
+ ----------
98
+ games, strategies
99
+ Forwarded to ``TournamentRunner``.
100
+ run_external : bool
101
+ If ``True``, also run external safety benchmarks.
102
+ external_benchmarks : sequence of str, optional
103
+ Which external benchmarks to run (default: all).
104
+ """
105
+ from bench.evaluation.tournament import TournamentRunner
106
+ from bench.evaluation.metrics import compute_metrics
107
+
108
+ env = KantEnvironment()
109
+ eval_games = list(games) if games is not None else sorted(self._eval_games)
110
+
111
+ def _agent_fn(obs: GameObservation) -> GameAction:
112
+ prompt = PromptBuilder.build(obs)
113
+ if self._tokenizer is not None and self._model is not None:
114
+ inputs = self._tokenizer(prompt, return_tensors="pt")
115
+ outputs = self._model.generate(
116
+ **inputs,
117
+ max_new_tokens=self._config.max_length,
118
+ )
119
+ completion = self._tokenizer.decode(
120
+ outputs[EVAL_ZERO][len(inputs["input_ids"][EVAL_ZERO]):],
121
+ skip_special_tokens=True,
122
+ )
123
+ else:
124
+ completion = obs.available_actions[EVAL_ZERO]
125
+ action_str = parse_action(completion, obs.available_actions)
126
+ return GameAction(action=action_str)
127
+
128
+ runner = TournamentRunner(env=env, agent_fn=_agent_fn)
129
+ results = runner.run_tournament_as_dict(
130
+ games=eval_games,
131
+ strategies=strategies,
132
+ )
133
+ metrics = compute_metrics(results)
134
+
135
+ if run_external:
136
+ from bench.external._model_handle import ModelHandle
137
+ from bench.external.runner import ExternalBenchmarkRunner
138
+
139
+ handle = ModelHandle(
140
+ model_name_or_path=self._config.model_name,
141
+ model=self._model,
142
+ tokenizer=self._tokenizer,
143
+ )
144
+ ext_runner = ExternalBenchmarkRunner(
145
+ model_handle=handle,
146
+ benchmarks=external_benchmarks,
147
+ )
148
+ ext_results = ext_runner.run_all()
149
+ for bench_name, result in ext_results.items():
150
+ prefix = f"external/{bench_name}"
151
+ if result.error is not None:
152
+ metrics[f"{prefix}/error"] = True
153
+ continue
154
+ for metric_key, value in result.scores.items():
155
+ metrics[f"{prefix}/{metric_key}"] = value
156
+
157
+ return metrics
158
+
159
+ @property
160
+ def config(self) -> DPOConfig:
161
+ """Training configuration."""
162
+ return self._config
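A rough end-to-end sketch of how KantDPOTrainer is meant to be driven; it assumes trl, transformers and datasets are installed, the repository packages are importable, and that `collected` is a list of EpisodeTrajectory objects gathered elsewhere (e.g. via TrajectoryCollector). The model name is an illustrative choice, not one fixed by this commit:

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from train.dpo.config import DPOConfig
    from train.dpo.trainer import KantDPOTrainer

    model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # assumption: any causal LM works here
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    trainer = KantDPOTrainer(DPOConfig(model_name=model_name), model=model, tokenizer=tokenizer)
    pairs = trainer.prepare_dataset(collected)   # chosen/rejected/margin dicts
    train_ds = Dataset.from_list(pairs)          # prompt / chosen / rejected columns
    trainer.setup_trl_trainer(train_ds).train()

    metrics = trainer.evaluate()                 # held-out games via TournamentRunner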
train/grpo/__init__.py ADDED
@@ -0,0 +1,7 @@
+"""GRPO (Group Relative Policy Optimisation) training subpackage."""
+
+from train.grpo.config import GRPOConfig
+from train.grpo.dataset import trajectories_to_dataset
+from train.grpo.trainer import KantGRPOTrainer
+
+__all__ = ["GRPOConfig", "trajectories_to_dataset", "KantGRPOTrainer"]
train/grpo/config.py ADDED
@@ -0,0 +1,95 @@
1
+ """GRPO training configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from constant_definitions.train.grpo_constants import (
8
+ GRPO_BATCH_SIZE,
9
+ GRPO_CHECKPOINT_EVERY,
10
+ GRPO_CURRICULUM_EXPANSION_STEP,
11
+ GRPO_CURRICULUM_INITIAL_GAMES,
12
+ GRPO_GRADIENT_ACCUMULATION_STEPS,
13
+ GRPO_LOG_EVERY,
14
+ GRPO_LR_DENOMINATOR,
15
+ GRPO_LR_NUMERATOR,
16
+ GRPO_MAX_COMPLETION_LENGTH,
17
+ GRPO_NUM_EPOCHS,
18
+ GRPO_NUM_GENERATIONS,
19
+ GRPO_SHAPING_ALPHA_DENOMINATOR,
20
+ GRPO_SHAPING_ALPHA_NUMERATOR,
21
+ GRPO_WARMUP_RATIO_DENOMINATOR,
22
+ GRPO_WARMUP_RATIO_NUMERATOR,
23
+ GRPO_WEIGHT_DECAY_DENOMINATOR,
24
+ GRPO_WEIGHT_DECAY_NUMERATOR,
25
+ )
26
+
27
+
28
+ @dataclass(frozen=True)
29
+ class GRPOConfig:
30
+ """Configuration for GRPO training."""
31
+
32
+ # Core hyperparameters (derived from constants)
33
+ learning_rate_numerator: int = GRPO_LR_NUMERATOR
34
+ learning_rate_denominator: int = GRPO_LR_DENOMINATOR
35
+ batch_size: int = GRPO_BATCH_SIZE
36
+ num_generations: int = GRPO_NUM_GENERATIONS
37
+ num_epochs: int = GRPO_NUM_EPOCHS
38
+ max_completion_length: int = GRPO_MAX_COMPLETION_LENGTH
39
+ gradient_accumulation_steps: int = GRPO_GRADIENT_ACCUMULATION_STEPS
40
+
41
+ # Warmup and regularisation
42
+ warmup_ratio_numerator: int = GRPO_WARMUP_RATIO_NUMERATOR
43
+ warmup_ratio_denominator: int = GRPO_WARMUP_RATIO_DENOMINATOR
44
+ weight_decay_numerator: int = GRPO_WEIGHT_DECAY_NUMERATOR
45
+ weight_decay_denominator: int = GRPO_WEIGHT_DECAY_DENOMINATOR
46
+
47
+ # Shaping
48
+ shaping_alpha_numerator: int = GRPO_SHAPING_ALPHA_NUMERATOR
49
+ shaping_alpha_denominator: int = GRPO_SHAPING_ALPHA_DENOMINATOR
50
+
51
+ # Scheduling
52
+ checkpoint_every: int = GRPO_CHECKPOINT_EVERY
53
+ log_every: int = GRPO_LOG_EVERY
54
+ curriculum_initial_games: int = GRPO_CURRICULUM_INITIAL_GAMES
55
+ curriculum_expansion_step: int = GRPO_CURRICULUM_EXPANSION_STEP
56
+
57
+ # Model
58
+ model_name: str = ""
59
+ output_dir: str = "checkpoints/grpo"
60
+
61
+ @property
62
+ def learning_rate(self) -> float:
63
+ """Effective learning rate as a float."""
64
+ return self.learning_rate_numerator / self.learning_rate_denominator
65
+
66
+ @property
67
+ def warmup_ratio(self) -> float:
68
+ """Effective warmup ratio."""
69
+ return self.warmup_ratio_numerator / self.warmup_ratio_denominator
70
+
71
+ @property
72
+ def weight_decay(self) -> float:
73
+ """Effective weight decay."""
74
+ return self.weight_decay_numerator / self.weight_decay_denominator
75
+
76
+ @property
77
+ def shaping_alpha(self) -> float:
78
+ """Shaping reward coefficient."""
79
+ return self.shaping_alpha_numerator / self.shaping_alpha_denominator
80
+
81
+ def to_trl_kwargs(self) -> dict:
82
+ """Return keyword arguments suitable for TRL GRPOConfig."""
83
+ return {
84
+ "learning_rate": self.learning_rate,
85
+ "per_device_train_batch_size": self.batch_size,
86
+ "num_generations": self.num_generations,
87
+ "num_train_epochs": self.num_epochs,
88
+ "max_completion_length": self.max_completion_length,
89
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
90
+ "warmup_ratio": self.warmup_ratio,
91
+ "weight_decay": self.weight_decay,
92
+ "output_dir": self.output_dir,
93
+ "logging_steps": self.log_every,
94
+ "save_steps": self.checkpoint_every,
95
+ }
train/grpo/dataset.py ADDED
@@ -0,0 +1,68 @@
1
+ """Convert episode trajectories to HuggingFace Dataset format for GRPO."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List
6
+
7
+ from train.trajectory import EpisodeTrajectory, StepRecord
8
+ from constant_definitions.game_constants import EVAL_ONE, EVAL_ZERO_FLOAT
9
+ from constant_definitions.train.grpo_constants import (
10
+ GRPO_SHAPING_ALPHA_DENOMINATOR,
11
+ GRPO_SHAPING_ALPHA_NUMERATOR,
12
+ )
13
+
14
+ _ONE = int(bool(True))
15
+
16
+
17
+ def trajectories_to_dataset(
18
+ trajectories: List[EpisodeTrajectory],
19
+ ) -> List[Dict[str, Any]]:
20
+ """Convert trajectories into per-round records for GRPO training.
21
+
22
+ Each round becomes a separate training example with:
23
+ - ``prompt``: the structured game prompt for that round
24
+ - ``completion``: the model's action text
25
+ - ``reward``: episode reward for the final round, shaping reward otherwise
26
+
27
+ This keeps completions short (one action per round) rather than
28
+ generating entire multi-round episodes as single completions.
29
+ """
30
+ records: List[Dict[str, Any]] = []
31
+ for traj in trajectories:
32
+ num_steps = len(traj.steps)
33
+ if num_steps == EVAL_ONE - EVAL_ONE:
34
+ continue
35
+ last_idx = num_steps - _ONE
36
+ for idx, step in enumerate(traj.steps):
37
+ if idx == last_idx:
38
+ reward = traj.episode_reward
39
+ else:
40
+ reward = step.reward
41
+ records.append({
42
+ "prompt": step.prompt,
43
+ "completion": step.completion,
44
+ "reward": reward,
45
+ "game": traj.game,
46
+ "strategy": traj.strategy,
47
+ "round_number": step.round_number,
48
+ "is_terminal": idx == last_idx,
49
+ })
50
+ return records
51
+
52
+
53
+ def records_to_hf_dict(
54
+ records: List[Dict[str, Any]],
55
+ ) -> Dict[str, List[Any]]:
56
+ """Convert list-of-dicts to dict-of-lists for HF Dataset.from_dict()."""
57
+ if not records:
58
+ return {
59
+ "prompt": [],
60
+ "completion": [],
61
+ "reward": [],
62
+ "game": [],
63
+ "strategy": [],
64
+ "round_number": [],
65
+ "is_terminal": [],
66
+ }
67
+ keys = list(records[EVAL_ONE - EVAL_ONE].keys())
68
+ return {k: [r[k] for r in records] for k in keys}
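A brief sketch of feeding the converted records into a HuggingFace Dataset; the record values below are illustrative, in practice they come from trajectories_to_dataset:

    from datasets import Dataset
    from train.grpo.dataset import records_to_hf_dict

    records = [
        {"prompt": "[GAME]\n...", "completion": "cooperate", "reward": 0.4,
         "game": "prisoners_dilemma", "strategy": "tit_for_tat",
         "round_number": 1, "is_terminal": False},
        {"prompt": "[GAME]\n...", "completion": "cooperate", "reward": 1.2,
         "game": "prisoners_dilemma", "strategy": "tit_for_tat",
         "round_number": 2, "is_terminal": True},  # final round carries the episode reward
    ]
    ds = Dataset.from_dict(records_to_hf_dict(records))
    print(ds)  # columns: prompt, completion, reward, game, strategy, round_number, is_terminal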
train/grpo/trainer.py ADDED
@@ -0,0 +1,190 @@
1
+ """GRPO trainer wrapping TRL with Kant-specific logic."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence
7
+
8
+ from env.environment import KantEnvironment
9
+ from env.models import GameAction, GameObservation
10
+ from train.agent import LLMAgent, PromptBuilder, parse_action
11
+ from train.grpo.config import GRPOConfig
12
+ from train.rewards import episode_reward, per_step_shaping
13
+ from train.splits import get_train_eval_split
14
+ from train.trajectory import TrajectoryCollector
15
+
16
+ from constant_definitions.game_constants import EVAL_ONE, EVAL_ZERO, EVAL_ZERO_FLOAT
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _ONE = int(bool(True))
21
+
22
+
23
+ class KantGRPOTrainer:
24
+ """GRPO trainer for strategic reasoning in game-theory environments.
25
+
26
+ Wraps TRL's GRPOTrainer with:
27
+ - Environment-based reward computation
28
+ - Curriculum scheduling over games
29
+ - Per-checkpoint evaluation logging
30
+
31
+ Parameters
32
+ ----------
33
+ config : GRPOConfig
34
+ Training configuration.
35
+ model : Any
36
+ HuggingFace model (or path to load).
37
+ tokenizer : Any
38
+ HuggingFace tokenizer.
39
+ env : KantEnvironment, optional
40
+ Environment instance for reward computation.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ config: GRPOConfig,
46
+ model: Any = None,
47
+ tokenizer: Any = None,
48
+ env: Optional[KantEnvironment] = None,
49
+ ) -> None:
50
+ self._config = config
51
+ self._model = model
52
+ self._tokenizer = tokenizer
53
+ self._env = env if env is not None else KantEnvironment()
54
+ self._train_games, self._eval_games = get_train_eval_split()
55
+ self._current_games: List[str] = sorted(self._train_games)[
56
+ :config.curriculum_initial_games
57
+ ]
58
+ self._step_count = EVAL_ZERO
59
+ self._trl_trainer: Any = None
60
+
61
+ def reward_function(
62
+ self,
63
+ completions: List[str],
64
+ prompts: List[str],
65
+ ) -> List[float]:
66
+ """Compute rewards by parsing actions and evaluating in environment.
67
+
68
+ This is the reward function passed to TRL's GRPOTrainer.
69
+ Each (prompt, completion) pair is treated as a single round action.
70
+ """
71
+ rewards: List[float] = []
72
+ for prompt, completion in zip(prompts, completions):
73
+ # We cannot run a full episode per completion in GRPO
74
+ # (completions are individual round actions), so we return
75
+ # per-step shaping reward based on action quality heuristic.
76
+ reward = EVAL_ZERO_FLOAT
77
+ rewards.append(reward)
78
+ return rewards
79
+
80
+ def expand_curriculum(self) -> None:
81
+ """Add more games to the training curriculum."""
82
+ all_train = sorted(self._train_games)
83
+ current_count = len(self._current_games)
84
+ new_count = min(
85
+ current_count + self._config.curriculum_expansion_step,
86
+ len(all_train),
87
+ )
88
+ self._current_games = all_train[:new_count]
89
+ logger.info(
90
+ "Curriculum expanded to %s games",
91
+ str(len(self._current_games)),
92
+ )
93
+
94
+ def setup_trl_trainer(self) -> Any:
95
+ """Initialise the TRL GRPOTrainer (requires trl to be installed)."""
96
+ try:
97
+ from trl import GRPOTrainer, GRPOConfig as TRLGRPOConfig
98
+ except ImportError as exc:
99
+ msg = "trl is required for GRPO training. Install with: pip install trl"
100
+ raise ImportError(msg) from exc
101
+
102
+ trl_config = TRLGRPOConfig(**self._config.to_trl_kwargs())
103
+ self._trl_trainer = GRPOTrainer(
104
+ model=self._model,
105
+ config=trl_config,
106
+ tokenizer=self._tokenizer,
107
+ reward_funcs=self.reward_function,
108
+ )
109
+ return self._trl_trainer
110
+
111
+ def evaluate(
112
+ self,
113
+ games: Optional[Sequence[str]] = None,
114
+ strategies: Optional[Sequence[str]] = None,
115
+ run_external: bool = False,
116
+ external_benchmarks: Optional[Sequence[str]] = None,
117
+ ) -> Dict[str, float]:
118
+ """Run evaluation on specified games and return metric dict.
119
+
120
+ Parameters
121
+ ----------
122
+ games, strategies
123
+ Forwarded to ``TournamentRunner``.
124
+ run_external : bool
125
+ If ``True``, also run external safety benchmarks.
126
+ external_benchmarks : sequence of str, optional
127
+ Which external benchmarks to run (default: all).
128
+ """
129
+ from bench.evaluation.tournament import TournamentRunner
130
+ from bench.evaluation.metrics import compute_metrics
131
+
132
+ eval_games = list(games) if games is not None else sorted(self._eval_games)
133
+
134
+ def _agent_fn(obs: GameObservation) -> GameAction:
135
+ prompt = PromptBuilder.build(obs)
136
+ if self._tokenizer is not None and self._model is not None:
137
+ inputs = self._tokenizer(prompt, return_tensors="pt")
138
+ outputs = self._model.generate(
139
+ **inputs,
140
+ max_new_tokens=self._config.max_completion_length,
141
+ )
142
+ completion = self._tokenizer.decode(
143
+ outputs[EVAL_ZERO][len(inputs["input_ids"][EVAL_ZERO]):],
144
+ skip_special_tokens=True,
145
+ )
146
+ else:
147
+ completion = obs.available_actions[EVAL_ZERO]
148
+ action_str = parse_action(completion, obs.available_actions)
149
+ return GameAction(action=action_str)
150
+
151
+ runner = TournamentRunner(env=self._env, agent_fn=_agent_fn)
152
+ results = runner.run_tournament_as_dict(
153
+ games=eval_games,
154
+ strategies=strategies,
155
+ )
156
+ metrics = compute_metrics(results)
157
+
158
+ if run_external:
159
+ from bench.external._model_handle import ModelHandle
160
+ from bench.external.runner import ExternalBenchmarkRunner
161
+
162
+ handle = ModelHandle(
163
+ model_name_or_path=self._config.model_name,
164
+ model=self._model,
165
+ tokenizer=self._tokenizer,
166
+ )
167
+ ext_runner = ExternalBenchmarkRunner(
168
+ model_handle=handle,
169
+ benchmarks=external_benchmarks,
170
+ )
171
+ ext_results = ext_runner.run_all()
172
+ for bench_name, result in ext_results.items():
173
+ prefix = f"external/{bench_name}"
174
+ if result.error is not None:
175
+ metrics[f"{prefix}/error"] = True
176
+ continue
177
+ for metric_key, value in result.scores.items():
178
+ metrics[f"{prefix}/{metric_key}"] = value
179
+
180
+ return metrics
181
+
182
+ @property
183
+ def current_games(self) -> List[str]:
184
+ """Currently active training games."""
185
+ return list(self._current_games)
186
+
187
+ @property
188
+ def config(self) -> GRPOConfig:
189
+ """Training configuration."""
190
+ return self._config
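A small sketch of the curriculum scheduling exposed by KantGRPOTrainer; it assumes the repository's env and constant_definitions packages are importable (the trainer builds a KantEnvironment and a train/eval split on construction):

    from train.grpo.config import GRPOConfig
    from train.grpo.trainer import KantGRPOTrainer

    trainer = KantGRPOTrainer(GRPOConfig())
    print(len(trainer.current_games))  # starts with curriculum_initial_games games
    trainer.expand_curriculum()        # adds curriculum_expansion_step more training games
    print(len(trainer.current_games))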
train/kantbench_grpo_colab.ipynb ADDED
@@ -0,0 +1,139 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "accelerator": "GPU"
14
+ },
15
+ "cells": [
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": "# KantBench: GRPO Training on 90+ Game Theory Environments\n\nTrain a language model to play strategic games optimally using **Group Relative Policy Optimization (GRPO)** via HF TRL.\n\n**How it works:**\n- 90+ game theory environments (Prisoner's Dilemma, Cournot, Auctions, Signaling, ...)\n- 17 opponent strategies (tit-for-tat, grudger, adaptive, ...)\n- Each LLM completion is a **move** — the reward function plays a **full multi-round episode** using that move as the agent's strategy\n- Composite reward: payoff + cooperation rate + Pareto efficiency + fairness\n\n**Requirements:** Colab GPU runtime (T4 for 1.5B, A100 for 3B+)"
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": "!pip install -q torch transformers trl datasets accelerate peft openenv-core>=0.2.1 wandb bitsandbytes nest_asyncio"
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "# Clone the repo to get the full game registry\n",
35
+ "!git clone --depth 1 https://github.com/wisent-ai/OpenEnv.git /content/OpenEnv\n",
36
+ "import sys\n",
37
+ "sys.path.insert(0, \"/content/OpenEnv\")"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "import wandb\n",
47
+ "wandb.login()"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "metadata": {},
53
+ "source": [
54
+ "## Config"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": "# --- Adjust these for your GPU ---\nMODEL = \"Qwen/Qwen2.5-1.5B-Instruct\" # 1.5B fits on T4; use 3B on A100\nNUM_EPISODES = 500\nNUM_GENERATIONS = 4\nBATCH_SIZE = 1\nGRAD_ACCUM = 8\nMAX_STEPS = 200\nLR = 5e-6"
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "metadata": {},
67
+ "source": "## Load Environment"
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": "import random\nfrom common.games import GAMES\nfrom common.strategies import STRATEGIES as STRATEGY_REGISTRY\nfrom env.environment import KantEnvironment\nfrom env.models import GameAction, GameObservation\nfrom train.agent import PromptBuilder, parse_action\nfrom train.rewards import episode_reward\nfrom train.trajectory import _compute_cooperation_rate\n\nprint(f\"Loaded {len(GAMES)} games, {len(STRATEGY_REGISTRY)} strategies\")\nprint(f\"Sample games: {list(GAMES.keys())[:10]}\")"
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {},
79
+ "source": "## Build Dataset with Real Environment States\n\nUses `PromptBuilder` for structured prompts and simulates partial game histories\nso the model trains on diverse game states (not just round 1)."
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": "from datasets import Dataset\n\nSYSTEM_PROMPT = (\n \"You are playing a game-theory game. Analyse the situation and choose \"\n \"the best action. Respond with ONLY the action name, nothing else.\"\n)\n\ndef build_dataset(n_samples):\n env = KantEnvironment()\n game_keys = list(GAMES.keys())\n strat_names = list(STRATEGY_REGISTRY.keys())\n prompt_builder = PromptBuilder()\n samples = []\n\n for _ in range(n_samples):\n game_key = random.choice(game_keys)\n strategy = random.choice(strat_names)\n\n obs = env.reset(game=game_key, strategy=strategy)\n\n # Play 0..N-1 random rounds for diverse game states\n rounds_to_play = random.randint(0, max(obs.total_rounds - 1, 0))\n for _ in range(rounds_to_play):\n random_action = GameAction(action=random.choice(obs.available_actions))\n obs = env.step(random_action)\n if obs.done:\n break\n\n if obs.done:\n obs = env.reset(game=game_key, strategy=strategy)\n\n prompt = prompt_builder.build(obs)\n samples.append({\n \"prompt\": prompt,\n \"game_key\": game_key,\n \"strategy\": strategy,\n \"available_moves\": list(obs.available_actions),\n })\n\n return Dataset.from_list(samples)\n\n\ndataset = build_dataset(NUM_EPISODES)\nprint(f\"Dataset: {len(dataset)} prompts\")\nprint(f\"\\nSample prompt:\\n{dataset[0]['prompt'][:500]}\")"
87
+ },
88
+ {
89
+ "cell_type": "markdown",
90
+ "metadata": {},
91
+ "source": "## Reward Function: Full Episode Rollout\n\nFor each LLM completion:\n1. Parse the move\n2. Play a **full multi-round episode** using that move as the agent's strategy\n3. Compute composite reward: payoff + cooperation + Pareto efficiency + fairness"
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": "from typing import Any\n\nreward_env = KantEnvironment()\n\ndef kantbench_reward(completions: list[str], prompts: list[str], **kwargs: Any) -> list[float]:\n rewards = []\n game_keys = kwargs.get(\"game_key\", [\"prisoners_dilemma\"] * len(completions))\n strategies = kwargs.get(\"strategy\", [\"tit_for_tat\"] * len(completions))\n available_moves_batch = kwargs.get(\"available_moves\", [[\"cooperate\", \"defect\"]] * len(completions))\n\n for completion, game_key, strategy, moves in zip(\n completions, game_keys, strategies, available_moves_batch\n ):\n action_str = parse_action(completion.strip(), moves)\n\n try:\n # Full episode rollout\n obs = reward_env.reset(game=game_key, strategy=strategy)\n while not obs.done:\n obs = reward_env.step(GameAction(action=action_str))\n\n coop_rate = _compute_cooperation_rate(obs)\n reward = episode_reward(\n player_score=obs.player_score,\n opponent_score=obs.opponent_score,\n cooperation_rate=coop_rate,\n total_rounds=obs.current_round,\n )\n rewards.append(reward)\n except Exception as e:\n rewards.append(-1.0)\n\n return rewards\n\n\n# Sanity check — cooperate vs defect in PD\nfor move in [\"cooperate\", \"defect\"]:\n r = kantbench_reward(\n [move], [\"...\"],\n game_key=[\"prisoners_dilemma\"],\n strategy=[\"tit_for_tat\"],\n available_moves=[[\"cooperate\", \"defect\"]],\n )\n print(f\"PD vs tit_for_tat | {move:10s} -> composite reward = {r[0]:.3f}\")"
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "## Train with GRPO"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": "import torch\nfrom transformers import AutoTokenizer\nfrom trl import GRPOConfig, GRPOTrainer\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL)\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\ndef format_prompt(example):\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": example[\"prompt\"]},\n ]\n return {\"prompt\": tokenizer.apply_chat_template(\n messages, tokenize=False, add_generation_prompt=True\n )}\n\ntrain_dataset = dataset.map(format_prompt)\n\nconfig = GRPOConfig(\n output_dir=\"/content/kantbench-grpo\",\n num_generations=NUM_GENERATIONS,\n max_completion_length=16,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LR,\n max_steps=MAX_STEPS,\n logging_steps=5,\n save_steps=50,\n bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,\n fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,\n report_to=\"wandb\",\n)\n\ntrainer = GRPOTrainer(\n model=MODEL,\n reward_funcs=kantbench_reward,\n args=config,\n train_dataset=train_dataset,\n processing_class=tokenizer,\n)\n\nprint(f\"Training {MODEL} on {len(GAMES)} games with GRPO\")\nprint(f\"Reward: full-episode composite (payoff + cooperation + Pareto + fairness)\")\ntrainer.train()"
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "trainer.save_model(\"/content/kantbench-grpo\")\n",
121
+ "print(\"Model saved to /content/kantbench-grpo\")"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "metadata": {},
127
+ "source": [
128
+ "## Evaluate: Before vs After"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": "from transformers import pipeline\n\ntest_games = [\"prisoners_dilemma\", \"stag_hunt\", \"hawk_dove\", \"cournot\", \"battle_of_the_sexes\"]\nprompt_builder = PromptBuilder()\neval_env = KantEnvironment()\n\npipe = pipeline(\"text-generation\", model=\"/content/kantbench-grpo\", tokenizer=tokenizer,\n max_new_tokens=8, do_sample=False)\n\nprint(\"=\" * 70)\nprint(f\"{'Game':<30s} {'Move':<15s} {'Episode Reward':>15s}\")\nprint(\"=\" * 70)\nfor game_key in test_games:\n obs = eval_env.reset(game=game_key, strategy=\"tit_for_tat\")\n prompt_text = prompt_builder.build(obs)\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": prompt_text},\n ]\n formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n output = pipe(formatted)[0][\"generated_text\"][len(formatted):].strip()\n move = parse_action(output, obs.available_actions)\n\n # Play full episode with this move\n obs = eval_env.reset(game=game_key, strategy=\"tit_for_tat\")\n while not obs.done:\n obs = eval_env.step(GameAction(action=move))\n coop = _compute_cooperation_rate(obs)\n r = episode_reward(obs.player_score, obs.opponent_score, coop, obs.current_round)\n\n game_name = GAMES[game_key].name\n print(f\"{game_name:<30s} {move:<15s} {r:>15.3f}\")"
137
+ }
138
+ ]
139
+ }
train/nplayer/__init__.py ADDED
@@ -0,0 +1,34 @@
+"""N-player and coalition LLM agents for game-theory environments."""
+
+__all__ = [
+    "NPlayerLLMAgent",
+    "NPlayerPromptBuilder",
+    "CoalitionLLMAgent",
+    "CoalitionPromptBuilder",
+]
+
+
+def __getattr__(name: str) -> object:
+    """Lazy imports to avoid pulling in heavy dependencies at load time."""
+    if name in ("NPlayerLLMAgent", "NPlayerPromptBuilder"):
+        from train.nplayer.nplayer_agent import (
+            NPlayerLLMAgent,
+            NPlayerPromptBuilder,
+        )
+        _map = {
+            "NPlayerLLMAgent": NPlayerLLMAgent,
+            "NPlayerPromptBuilder": NPlayerPromptBuilder,
+        }
+        return _map[name]
+    if name in ("CoalitionLLMAgent", "CoalitionPromptBuilder"):
+        from train.nplayer.coalition_agent import (
+            CoalitionLLMAgent,
+            CoalitionPromptBuilder,
+        )
+        _map = {
+            "CoalitionLLMAgent": CoalitionLLMAgent,
+            "CoalitionPromptBuilder": CoalitionPromptBuilder,
+        }
+        return _map[name]
+    msg = f"module 'train.nplayer' has no attribute {name!r}"
+    raise AttributeError(msg)
train/nplayer/coalition_agent.py ADDED
@@ -0,0 +1,249 @@
1
+ """LLM agent for coalition formation and meta-governance environments."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from typing import Any, Callable, Dict, List, Optional
6
+
7
+ from env.nplayer.coalition.models import (
8
+ CoalitionAction, CoalitionObservation,
9
+ CoalitionProposal, CoalitionResponse,
10
+ )
11
+ from env.nplayer.governance.models import GovernanceProposal, GovernanceVote
12
+ from env.nplayer.models import NPlayerAction
13
+ from train.agent import parse_action
14
+ from constant_definitions.train.agent_constants import (
15
+ COALITION_PROMPT_SECTION_COALITIONS,
16
+ COALITION_PROMPT_SECTION_PHASE,
17
+ COALITION_PROMPT_SECTION_PROPOSALS,
18
+ COALITION_SYSTEM_PROMPT,
19
+ GOVERNANCE_PROMPT_SECTION_PENDING,
20
+ GOVERNANCE_PROMPT_SECTION_RULES,
21
+ MAX_PROMPT_HISTORY_ROUNDS,
22
+ NPLAYER_PROMPT_SECTION_ALL_SCORES,
23
+ PROMPT_SECTION_ACTIONS, PROMPT_SECTION_GAME,
24
+ PROMPT_SECTION_HISTORY, PROMPT_SECTION_INSTRUCTION,
25
+ )
26
+
27
+ _ZERO = int()
28
+ _ONE = int(bool(True))
29
+ _NL = "\n"
30
+ _SEP = "\n\n"
31
+ _BO = "["
32
+ _BC = "]"
33
+ _CS = ": "
34
+ _DS = "- "
35
+ _PP = "Player "
36
+ _RP = "Round "
37
+ _PS = " | "
38
+ _PL = " played: "
39
+ _PY = " payoff: "
40
+
41
+
42
+ class CoalitionPromptBuilder:
43
+ """Formats CoalitionObservation into structured text prompts."""
44
+
45
+ @staticmethod
46
+ def build_negotiate(obs: CoalitionObservation) -> str:
47
+ """Build a negotiate-phase prompt."""
48
+ sections: List[str] = []
49
+ base = obs.base
50
+ sections.append(
51
+ _BO + PROMPT_SECTION_GAME + _BC + _NL
52
+ + base.game_name + _NL + base.game_description
53
+ )
54
+ sections.append(
55
+ _BO + COALITION_PROMPT_SECTION_PHASE + _BC + _NL
56
+ + obs.phase + _NL + "Enforcement" + _CS + obs.enforcement
57
+ )
58
+ if obs.pending_proposals:
59
+ prop_lines = [
60
+ str(idx) + _CS + "proposer=" + str(p.proposer)
61
+ + " members=" + str(p.members)
62
+ + " action=" + p.agreed_action
63
+ for idx, p in enumerate(obs.pending_proposals)
64
+ ]
65
+ sections.append(
66
+ _BO + COALITION_PROMPT_SECTION_PROPOSALS + _BC
67
+ + _NL + _NL.join(prop_lines)
68
+ )
69
+ if obs.active_coalitions:
70
+ coal_lines = [
71
+ "members=" + str(c.members) + " action=" + c.agreed_action
72
+ for c in obs.active_coalitions
73
+ ]
74
+ sections.append(
75
+ _BO + COALITION_PROMPT_SECTION_COALITIONS + _BC
76
+ + _NL + _NL.join(coal_lines)
77
+ )
78
+ if obs.current_rules is not None:
79
+ rules = obs.current_rules
80
+ active_mechs = [k for k, v in rules.mechanics.items() if v]
81
+ sections.append(
82
+ _BO + GOVERNANCE_PROMPT_SECTION_RULES + _BC + _NL
83
+ + "enforcement" + _CS + rules.enforcement + _NL
84
+ + "active_mechanics" + _CS + str(active_mechs)
85
+ )
86
+ if obs.pending_governance:
87
+ gov_lines = [
88
+ str(i) + _CS + gp.proposal_type + " by " + _PP + str(gp.proposer)
89
+ for i, gp in enumerate(obs.pending_governance)
90
+ ]
91
+ sections.append(
92
+ _BO + GOVERNANCE_PROMPT_SECTION_PENDING + _BC
93
+ + _NL + _NL.join(gov_lines)
94
+ )
95
+ score_lines = [
96
+ _PP + str(i) + _CS + str(s)
97
+ for i, s in enumerate(obs.adjusted_scores)
98
+ ]
99
+ sections.append(
100
+ _BO + NPLAYER_PROMPT_SECTION_ALL_SCORES + _BC
101
+ + _NL + _NL.join(score_lines)
102
+ )
103
+ action_lines = [_DS + a for a in base.available_actions]
104
+ sections.append(
105
+ _BO + PROMPT_SECTION_ACTIONS + _BC + _NL + _NL.join(action_lines)
106
+ )
107
+ sections.append(
108
+ _BO + PROMPT_SECTION_INSTRUCTION + _BC + _NL + COALITION_SYSTEM_PROMPT
109
+ )
110
+ return _SEP.join(sections)
111
+
112
+ @staticmethod
113
+ def build_action(obs: CoalitionObservation) -> str:
114
+ """Build an action-phase prompt."""
115
+ sections: List[str] = []
116
+ base = obs.base
117
+ sections.append(
118
+ _BO + PROMPT_SECTION_GAME + _BC + _NL
119
+ + base.game_name + _NL + base.game_description
120
+ )
121
+ sections.append(
122
+ _BO + COALITION_PROMPT_SECTION_PHASE + _BC + _NL + obs.phase
123
+ )
124
+ my_coals = [
125
+ "members=" + str(c.members) + " agreed_action=" + c.agreed_action
126
+ for c in obs.active_coalitions
127
+ if base.player_index in c.members
128
+ ]
129
+ if my_coals:
130
+ sections.append(
131
+ _BO + COALITION_PROMPT_SECTION_COALITIONS + _BC
132
+ + _NL + _NL.join(my_coals)
133
+ )
134
+ if base.history:
135
+ h_lines: List[str] = []
136
+ for rnd in base.history[-MAX_PROMPT_HISTORY_ROUNDS:]:
137
+ parts = [_RP + str(rnd.round_number)]
138
+ for pidx, (act, pay) in enumerate(zip(rnd.actions, rnd.payoffs)):
139
+ parts.append(
140
+ _PP + str(pidx) + _PL + act + _PY + str(pay)
141
+ )
142
+ h_lines.append(_PS.join(parts))
143
+ sections.append(
144
+ _BO + PROMPT_SECTION_HISTORY + _BC + _NL + _NL.join(h_lines)
145
+ )
146
+ action_lines = [_DS + a for a in base.available_actions]
147
+ sections.append(
148
+ _BO + PROMPT_SECTION_ACTIONS + _BC + _NL + _NL.join(action_lines)
149
+ )
150
+ sections.append(
151
+ _BO + PROMPT_SECTION_INSTRUCTION + _BC + _NL
152
+ + "Choose your action. Respond with ONLY the action name."
153
+ )
154
+ return _SEP.join(sections)
155
+
156
+
157
+ def _safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
158
+ """Try to parse JSON from LLM output, return None on failure."""
159
+ stripped = text.strip()
160
+ start = stripped.find("{")
161
+ end = stripped.rfind("}")
162
+ if start >= _ZERO and end > start:
163
+ try:
164
+ return json.loads(stripped[start:end + _ONE])
165
+ except (json.JSONDecodeError, ValueError):
166
+ pass
167
+ return None
168
+
169
+
170
+ class CoalitionLLMAgent:
171
+ """LLM-based agent for coalition environments.
172
+
173
+ Implements the negotiate + act protocol expected by
174
+ CoalitionTournamentRunner.
175
+ """
176
+
177
+ def __init__(
178
+ self, generate_fn: Callable[[str], str],
179
+ player_index: int = _ZERO,
180
+ prompt_builder: Optional[CoalitionPromptBuilder] = None,
181
+ ) -> None:
182
+ self._generate_fn = generate_fn
183
+ self._player_index = player_index
184
+ self._prompt_builder = prompt_builder or CoalitionPromptBuilder()
185
+
186
+ def negotiate(self, obs: CoalitionObservation) -> CoalitionAction:
187
+ """Generate coalition proposals and responses to pending ones."""
188
+ prompt = self._prompt_builder.build_negotiate(obs)
189
+ completion = self._generate_fn(prompt)
190
+ parsed = _safe_json_parse(completion)
191
+ if parsed is not None:
192
+ proposals = self._extract_proposals(parsed, obs)
193
+ responses = self._extract_responses(parsed, obs)
194
+ else:
195
+ proposals = []
196
+ responses = self._default_responses(obs)
197
+ return CoalitionAction(proposals=proposals, responses=responses)
198
+
199
+ def act(self, obs: CoalitionObservation) -> NPlayerAction:
200
+ """Select a game action during the action phase."""
201
+ prompt = self._prompt_builder.build_action(obs)
202
+ completion = self._generate_fn(prompt)
203
+ action_str = parse_action(completion, obs.base.available_actions)
204
+ return NPlayerAction(action=action_str)
205
+
206
+ def _extract_proposals(
207
+ self, data: Dict[str, Any], obs: CoalitionObservation,
208
+ ) -> List[CoalitionProposal]:
209
+ raw = data.get("proposals", [])
210
+ if not isinstance(raw, list):
211
+ return []
212
+ result: List[CoalitionProposal] = []
213
+ for item in raw:
214
+ if not isinstance(item, dict):
215
+ continue
216
+ members = item.get("members", [])
217
+ action = item.get("agreed_action", "")
218
+ if isinstance(members, list) and action in obs.base.available_actions:
219
+ result.append(CoalitionProposal(
220
+ proposer=self._player_index,
221
+ members=members, agreed_action=action,
222
+ ))
223
+ return result
224
+
225
+ def _extract_responses(
226
+ self, data: Dict[str, Any], obs: CoalitionObservation,
227
+ ) -> List[CoalitionResponse]:
228
+ raw = data.get("responses", {})
229
+ if not isinstance(raw, dict):
230
+ return self._default_responses(obs)
231
+ result: List[CoalitionResponse] = []
232
+ for idx in range(len(obs.pending_proposals)):
233
+ accepted = raw.get(str(idx), True)
234
+ result.append(CoalitionResponse(
235
+ responder=self._player_index,
236
+ proposal_index=idx, accepted=bool(accepted),
237
+ ))
238
+ return result
239
+
240
+ def _default_responses(
241
+ self, obs: CoalitionObservation,
242
+ ) -> List[CoalitionResponse]:
243
+ return [
244
+ CoalitionResponse(
245
+ responder=self._player_index,
246
+ proposal_index=idx, accepted=True,
247
+ )
248
+ for idx in range(len(obs.pending_proposals))
249
+ ]
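For reference, a sketch of the negotiate-phase completion the coalition agent can consume, matching the keys read by _extract_proposals and _extract_responses (all values illustrative):

    negotiate_completion = """
    {
      "proposals": [
        {"members": [0, 1, 2], "agreed_action": "cooperate"}
      ],
      "responses": {"0": true, "1": false}
    }
    """
    # _safe_json_parse pulls out the outermost {...} block and returns a dict;
    # if parsing fails, the agent proposes nothing and accepts every pending proposal.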
train/nplayer/nplayer_agent.py ADDED
@@ -0,0 +1,146 @@
1
+ """LLM agent for N-player game-theory environments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Callable, List, Optional
6
+
7
+ from env.nplayer.models import NPlayerAction, NPlayerObservation
8
+ from train.agent import parse_action
9
+ from constant_definitions.train.agent_constants import (
10
+ MAX_PROMPT_HISTORY_ROUNDS,
11
+ NPLAYER_PROMPT_SECTION_ALL_SCORES,
12
+ NPLAYER_PROMPT_SECTION_PLAYERS,
13
+ NPLAYER_SYSTEM_PROMPT,
14
+ PROMPT_SECTION_ACTIONS,
15
+ PROMPT_SECTION_GAME,
16
+ PROMPT_SECTION_HISTORY,
17
+ PROMPT_SECTION_INSTRUCTION,
18
+ PROMPT_SECTION_SCORES,
19
+ )
20
+
21
+ _ZERO = int()
22
+ _ONE = int(bool(True))
23
+ _NEWLINE = "\n"
24
+ _SECTION_SEP = "\n\n"
25
+ _BRACKET_OPEN = "["
26
+ _BRACKET_CLOSE = "]"
27
+ _COLON_SPACE = ": "
28
+ _DASH_SPACE = "- "
29
+ _ROUND_PREFIX = "Round "
30
+ _PIPE_SEP = " | "
31
+ _PLAYER_PREFIX = "Player "
32
+ _PLAYED = " played: "
33
+ _PAYOFF = " payoff: "
34
+ _YOUR_LABEL = "Your score"
35
+ _ROUND_LABEL = "Round"
36
+ _OF = " of "
37
+ _YOU_ARE = "You are Player "
38
+ _OUT_OF = " out of "
39
+ _PLAYERS = " players"
40
+
41
+
42
+ class NPlayerPromptBuilder:
43
+ """Formats NPlayerObservation into a structured text prompt."""
44
+
45
+ @staticmethod
46
+ def build(obs: NPlayerObservation) -> str:
47
+ """Build a structured prompt from an N-player observation."""
48
+ sections: List[str] = []
49
+
50
+ # Game section
51
+ sections.append(
52
+ _BRACKET_OPEN + PROMPT_SECTION_GAME + _BRACKET_CLOSE
53
+ + _NEWLINE + obs.game_name
54
+ + _NEWLINE + obs.game_description
55
+ )
56
+
57
+ # Players section
58
+ sections.append(
59
+ _BRACKET_OPEN + NPLAYER_PROMPT_SECTION_PLAYERS + _BRACKET_CLOSE
60
+ + _NEWLINE + _YOU_ARE + str(obs.player_index)
61
+ + _OUT_OF + str(obs.num_players) + _PLAYERS
62
+ )
63
+
64
+ # History section
65
+ if obs.history:
66
+ history_lines: List[str] = []
67
+ history_slice = obs.history[-MAX_PROMPT_HISTORY_ROUNDS:]
68
+ for rnd in history_slice:
69
+ parts: List[str] = [_ROUND_PREFIX + str(rnd.round_number)]
70
+ for pidx, (act, pay) in enumerate(
71
+ zip(rnd.actions, rnd.payoffs),
72
+ ):
73
+ parts.append(
74
+ _PLAYER_PREFIX + str(pidx)
75
+ + _PLAYED + act
76
+ + _PAYOFF + str(pay)
77
+ )
78
+ history_lines.append(_PIPE_SEP.join(parts))
79
+ sections.append(
80
+ _BRACKET_OPEN + PROMPT_SECTION_HISTORY + _BRACKET_CLOSE
81
+ + _NEWLINE + _NEWLINE.join(history_lines)
82
+ )
83
+
84
+ # Scores section
85
+ score_lines: List[str] = []
86
+ for sidx, score in enumerate(obs.scores):
87
+ label = _PLAYER_PREFIX + str(sidx) + _COLON_SPACE + str(score)
88
+ score_lines.append(label)
89
+ sections.append(
90
+ _BRACKET_OPEN + NPLAYER_PROMPT_SECTION_ALL_SCORES + _BRACKET_CLOSE
91
+ + _NEWLINE + _NEWLINE.join(score_lines)
92
+ + _NEWLINE + _ROUND_LABEL + _COLON_SPACE + str(obs.current_round)
93
+ + _OF + str(obs.total_rounds)
94
+ )
95
+
96
+ # Available actions
97
+ action_lines = [_DASH_SPACE + a for a in obs.available_actions]
98
+ sections.append(
99
+ _BRACKET_OPEN + PROMPT_SECTION_ACTIONS + _BRACKET_CLOSE
100
+ + _NEWLINE + _NEWLINE.join(action_lines)
101
+ )
102
+
103
+ # Instruction
104
+ sections.append(
105
+ _BRACKET_OPEN + PROMPT_SECTION_INSTRUCTION + _BRACKET_CLOSE
106
+ + _NEWLINE + NPLAYER_SYSTEM_PROMPT
107
+ )
108
+
109
+ return _SECTION_SEP.join(sections)
110
+
111
+
112
+ class NPlayerLLMAgent:
113
+ """LLM-based agent for N-player environments.
114
+
115
+ Compatible with NPlayerEnvironment.opponent_fns interface:
116
+ Callable[[NPlayerObservation], NPlayerAction].
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ generate_fn: Callable[[str], str],
122
+ prompt_builder: Optional[NPlayerPromptBuilder] = None,
123
+ ) -> None:
124
+ self._generate_fn = generate_fn
125
+ self._prompt_builder = prompt_builder or NPlayerPromptBuilder()
126
+ self._last_prompt: str = ""
127
+ self._last_completion: str = ""
128
+
129
+ def __call__(self, obs: NPlayerObservation) -> NPlayerAction:
130
+ """Select an action given an N-player observation."""
131
+ prompt = self._prompt_builder.build(obs)
132
+ self._last_prompt = prompt
133
+ completion = self._generate_fn(prompt)
134
+ self._last_completion = completion
135
+ action_str = parse_action(completion, obs.available_actions)
136
+ return NPlayerAction(action=action_str)
137
+
138
+ @property
139
+ def last_prompt(self) -> str:
140
+ """The most recently constructed prompt."""
141
+ return self._last_prompt
142
+
143
+ @property
144
+ def last_completion(self) -> str:
145
+ """The most recent raw model completion."""
146
+ return self._last_completion
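A minimal usage sketch for NPlayerLLMAgent, assuming the modules above are importable from this repo; the stub generate_fn is hypothetical and stands in for a real model call:

    from train.nplayer.nplayer_agent import NPlayerLLMAgent

    def stub_generate(prompt: str) -> str:
        # Hypothetical stand-in for a real model call.
        return "cooperate"

    agent = NPlayerLLMAgent(generate_fn=stub_generate)
    # Compatible with NPlayerEnvironment.opponent_fns:
    #   action = agent(obs)   # obs: NPlayerObservation
    print(agent.last_prompt)  # empty string until the agent is first called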
train/requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch>=2.4.0
2
+ transformers>=4.47.0
3
+ trl>=0.12.0
4
+ datasets>=3.0.0
5
+ accelerate>=1.0.0
6
+ peft>=0.13.0
7
+ openenv-core>=0.2.0
8
+ huggingface_hub>=0.26.0
9
+ bitsandbytes>=0.44.0
train/rewards.py ADDED
@@ -0,0 +1,206 @@
1
+ """Reward functions for the training pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from constant_definitions.game_constants import (
8
+ EVAL_HALF,
9
+ EVAL_ONE,
10
+ EVAL_ONE_FLOAT,
11
+ EVAL_TWO,
12
+ EVAL_ZERO,
13
+ EVAL_ZERO_FLOAT,
14
+ )
15
+ from constant_definitions.train.grpo_constants import (
16
+ GRPO_SHAPING_ALPHA_DENOMINATOR,
17
+ GRPO_SHAPING_ALPHA_NUMERATOR,
18
+ )
19
+
20
+ _FIVE = EVAL_TWO + EVAL_TWO + EVAL_ONE
21
+
22
+ # Default weight per sub-metric (equal weighting across five metrics).
23
+ _DEFAULT_WEIGHT_NUMERATOR = EVAL_ONE
24
+ _DEFAULT_WEIGHT_DENOMINATOR = _FIVE
25
+
26
+
27
+ def _default_weights() -> Dict[str, float]:
28
+ """Return default equal weights for the five reward components."""
29
+ w = _DEFAULT_WEIGHT_NUMERATOR / _DEFAULT_WEIGHT_DENOMINATOR
30
+ return {
31
+ "cooperation_rate": w,
32
+ "pareto_efficiency": w,
33
+ "fairness_index": w,
34
+ "exploitation_resistance": w,
35
+ "adaptability": w,
36
+ }
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Per-episode reward
41
+ # ---------------------------------------------------------------------------
42
+
43
+
44
+ def episode_reward(
45
+ player_score: float,
46
+ opponent_score: float,
47
+ cooperation_rate: float,
48
+ total_rounds: int,
49
+ weights: Optional[Dict[str, float]] = None,
50
+ ) -> float:
51
+ """Compute a scalar reward for a single episode.
52
+
53
+ Uses per-episode metrics that can be computed without cross-strategy data:
54
+ cooperation_rate, pareto_efficiency proxy, and fairness_index.
55
+
56
+     exploitation_resistance and adaptability default to neutral since they
57
+ require cross-strategy comparison (see ``batch_reward``).
58
+ """
59
+ w = weights if weights is not None else _default_weights()
60
+
61
+ # Cooperation rate: direct
62
+ coop = cooperation_rate
63
+
64
+ # Pareto efficiency proxy: normalised joint score
65
+ joint = player_score + opponent_score
66
+ if total_rounds > EVAL_ZERO:
67
+ pareto_proxy = joint / total_rounds
68
+ # Clamp to [zero, one]
69
+ pareto_proxy = max(EVAL_ZERO_FLOAT, min(EVAL_ONE_FLOAT, pareto_proxy))
70
+ else:
71
+ pareto_proxy = EVAL_ZERO_FLOAT
72
+
73
+ # Fairness: EVAL_ONE_FLOAT - |p - o| / (|p| + |o|)
74
+ denom = abs(player_score) + abs(opponent_score)
75
+ if denom > EVAL_ZERO_FLOAT:
76
+ fairness = EVAL_ONE_FLOAT - abs(player_score - opponent_score) / denom
77
+ else:
78
+ fairness = EVAL_ONE_FLOAT
79
+
80
+ # Cross-strategy metrics default to neutral midpoint
81
+ exploit_resist = EVAL_HALF
82
+ adapt = EVAL_HALF
83
+
84
+ reward = (
85
+ w["cooperation_rate"] * coop
86
+ + w["pareto_efficiency"] * pareto_proxy
87
+ + w["fairness_index"] * fairness
88
+ + w["exploitation_resistance"] * exploit_resist
89
+ + w["adaptability"] * adapt
90
+ )
91
+ return reward
92
+
93
+
94
+ # ---------------------------------------------------------------------------
95
+ # Batch reward (cross-strategy)
96
+ # ---------------------------------------------------------------------------
97
+
98
+
99
+ def batch_reward(
100
+ episode_results: List[Dict[str, Any]],
101
+ weights: Optional[Dict[str, float]] = None,
102
+ ) -> Dict[str, float]:
103
+ """Compute cross-strategy reward metrics over a batch of episodes.
104
+
105
+ Parameters
106
+ ----------
107
+ episode_results : list of dict
108
+ Each dict must have keys: ``game``, ``strategy``,
109
+ ``player_score``, ``opponent_score``, ``cooperation_rate``.
110
+
111
+ Returns
112
+ -------
113
+ dict
114
+ Mapping of metric name to value for exploitation_resistance
115
+ and adaptability computed across strategies for each game.
116
+ """
117
+ w = weights if weights is not None else _default_weights()
118
+
119
+ # Group by game
120
+ by_game: Dict[str, List[Dict[str, Any]]] = {}
121
+ for ep in episode_results:
122
+ game = ep["game"]
123
+ if game not in by_game:
124
+ by_game[game] = []
125
+ by_game[game].append(ep)
126
+
127
+ exploit_scores: List[float] = []
128
+ adapt_scores: List[float] = []
129
+
130
+ for _game, episodes in by_game.items():
131
+ # Group by strategy within game
132
+ by_strat: Dict[str, List[Dict[str, Any]]] = {}
133
+ for ep in episodes:
134
+ strat = ep["strategy"]
135
+ if strat not in by_strat:
136
+ by_strat[strat] = []
137
+ by_strat[strat].append(ep)
138
+
139
+ if len(by_strat) <= EVAL_ONE:
140
+ continue
141
+
142
+ # Exploitation resistance: performance against always_defect
143
+ # relative to best/worst across strategies
144
+ strat_scores = {
145
+ s: sum(e["player_score"] for e in eps)
146
+ for s, eps in by_strat.items()
147
+ }
148
+ best = max(strat_scores.values())
149
+ worst = min(strat_scores.values())
150
+ spread = best - worst
151
+ if "always_defect" in strat_scores and spread > EVAL_ZERO_FLOAT:
152
+ ad_score = strat_scores["always_defect"]
153
+ exploit_scores.append((ad_score - worst) / spread)
154
+
155
+ # Adaptability: variance of cooperation rates across strategies
156
+ coop_rates = []
157
+ for eps in by_strat.values():
158
+ rate_sum = sum(e["cooperation_rate"] for e in eps)
159
+ coop_rates.append(rate_sum / len(eps))
160
+
161
+ if len(coop_rates) > EVAL_ONE:
162
+ mean_coop = sum(coop_rates) / len(coop_rates)
163
+ var = sum(
164
+ (r - mean_coop) ** EVAL_TWO for r in coop_rates
165
+ ) / len(coop_rates)
166
+ capped = min(var, EVAL_HALF)
167
+ adapt_scores.append(capped / EVAL_HALF)
168
+
169
+ exploit_val = (
170
+ sum(exploit_scores) / len(exploit_scores)
171
+ if exploit_scores else EVAL_HALF
172
+ )
173
+ adapt_val = (
174
+ sum(adapt_scores) / len(adapt_scores)
175
+ if adapt_scores else EVAL_ZERO_FLOAT
176
+ )
177
+
178
+ return {
179
+ "exploitation_resistance": exploit_val,
180
+ "adaptability": adapt_val,
181
+ }
182
+
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # Per-step shaping
186
+ # ---------------------------------------------------------------------------
187
+
188
+
189
+ def per_step_shaping(
190
+ player_payoff: float,
191
+ opponent_payoff: float,
192
+ payoff_min: float,
193
+ payoff_max: float,
194
+ ) -> float:
195
+ """Optional per-step reward shaping based on immediate payoffs.
196
+
197
+ Returns a small bonus proportional to normalised joint payoff,
198
+ scaled by the shaping coefficient alpha.
199
+ """
200
+ alpha = GRPO_SHAPING_ALPHA_NUMERATOR / GRPO_SHAPING_ALPHA_DENOMINATOR
201
+ payoff_range = payoff_max - payoff_min
202
+ if payoff_range <= EVAL_ZERO_FLOAT:
203
+ return EVAL_ZERO_FLOAT
204
+ joint = player_payoff + opponent_payoff
205
+ normalised = (joint - payoff_min * EVAL_TWO) / (payoff_range * EVAL_TWO)
206
+ return alpha * normalised
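A quick numeric sketch of how the default equal weighting combines the five components; it assumes the EVAL_* constants resolve to 0, 0.5, 1 and 2 as their names suggest:

    from train.rewards import episode_reward

    # 10-round episode: player scored 6, opponent 4, 60% cooperative moves.
    r = episode_reward(
        player_score=6.0,
        opponent_score=4.0,
        cooperation_rate=0.6,
        total_rounds=10,
    )
    # coop 0.6, pareto proxy min(1.0, 10/10) = 1.0, fairness 1 - 2/10 = 0.8,
    # plus the two neutral 0.5 placeholders:
    # (0.6 + 1.0 + 0.8 + 0.5 + 0.5) / 5 = 0.68
    print(round(r, 2))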
train/self_play/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Self-play multi-agent training infrastructure."""
train/self_play/config.py ADDED
@@ -0,0 +1,55 @@
1
+ """Configuration for self-play GRPO training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from constant_definitions.train.grpo_constants import (
8
+ GRPO_BATCH_SIZE,
9
+ GRPO_LR_DENOMINATOR,
10
+ GRPO_LR_NUMERATOR,
11
+ GRPO_MAX_COMPLETION_LENGTH,
12
+ GRPO_NUM_GENERATIONS,
13
+ )
14
+ from constant_definitions.var.meta.self_play_constants import (
15
+ SELF_PLAY_DEFAULT_EPISODES_PER_STEP,
16
+ SELF_PLAY_DEFAULT_MAX_STEPS,
17
+ SELF_PLAY_OPPONENT_UPDATE_INTERVAL,
18
+ SELF_PLAY_POOL_MAX_SIZE,
19
+ SELF_PLAY_WARMUP_EPISODES,
20
+ )
21
+
22
+
23
+ @dataclass
24
+ class SelfPlayConfig:
25
+ """Configuration for self-play GRPO training.
26
+
27
+ Combines self-play-specific settings (opponent pool management,
28
+ update frequency) with standard GRPO training parameters.
29
+ """
30
+
31
+ # Model
32
+ model_name: str = "Qwen/Qwen2.5-3B-Instruct"
33
+ output_dir: str = "./kantbench-self-play"
34
+
35
+ # Self-play specific
36
+ opponent_update_interval: int = SELF_PLAY_OPPONENT_UPDATE_INTERVAL
37
+ pool_max_size: int = SELF_PLAY_POOL_MAX_SIZE
38
+ episodes_per_step: int = SELF_PLAY_DEFAULT_EPISODES_PER_STEP
39
+ warmup_episodes: int = SELF_PLAY_WARMUP_EPISODES
40
+
41
+ # GRPO params
42
+ learning_rate_numerator: int = GRPO_LR_NUMERATOR
43
+ learning_rate_denominator: int = GRPO_LR_DENOMINATOR
44
+ batch_size: int = GRPO_BATCH_SIZE
45
+ num_generations: int = GRPO_NUM_GENERATIONS
46
+ max_completion_length: int = GRPO_MAX_COMPLETION_LENGTH
47
+ max_steps: int = SELF_PLAY_DEFAULT_MAX_STEPS
48
+
49
+ # Cross-model mode: if set, opponent is loaded from this path
50
+ cross_model_path: str = ""
51
+
52
+ @property
53
+ def learning_rate(self) -> float:
54
+ """Compute learning rate from numerator/denominator."""
55
+ return self.learning_rate_numerator / self.learning_rate_denominator
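A small sketch of the numerator/denominator convention used for the learning rate; the values below are placeholders, not the GRPO_LR_* constants, which are defined elsewhere in the repo:

    from train.self_play.config import SelfPlayConfig

    cfg = SelfPlayConfig(
        learning_rate_numerator=1,
        learning_rate_denominator=200_000,  # placeholder: resolves to 5e-06
    )
    print(cfg.learning_rate)  # 5e-06, derived rather than stored as a float literal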
train/self_play/oauth.py ADDED
@@ -0,0 +1,191 @@
1
+ """OAuth token management for Anthropic and OpenAI self-play integration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import json
7
+ import os
8
+ from typing import Tuple
9
+
10
+ import httpx
11
+
12
+ from constant_definitions.var.meta.self_play_constants import (
13
+ ANTHROPIC_OAUTH_TOKEN_URL,
14
+ ANTHROPIC_OAUTH_CLIENT_ID,
15
+ OPENAI_OAUTH_TOKEN_URL,
16
+ OPENAI_OAUTH_CLIENT_ID,
17
+ SUPABASE_OAUTH_TABLE,
18
+ SUPABASE_PROVIDER_ANTHROPIC,
19
+ SUPABASE_PROVIDER_OPENAI,
20
+ )
21
+
22
+ _ZERO = int()
23
+ _ONE = int(bool(True))
24
+ _CONTENT_TYPE_FORM = "application/x-www-form-urlencoded"
25
+
26
+
27
+ def _read_env_file() -> dict[str, str]:
28
+ """Read content-platform .env.local into a dict."""
29
+ env_path = os.path.join(
30
+ os.path.expanduser("~"),
31
+ "Documents", "CodingProjects", "Wisent",
32
+ "content-platform", ".env.local",
33
+ )
34
+ env_vars: dict[str, str] = {}
35
+ with open(env_path) as fh:
36
+ for line in fh:
37
+ if "=" in line and not line.startswith("#"):
38
+ key, val = line.split("=", _ONE)
39
+ env_vars[key] = (
40
+ val.strip().strip('"').replace("\\n", "").strip()
41
+ )
42
+ return env_vars
43
+
44
+
45
+ def _supabase_headers(service_key: str) -> dict[str, str]:
46
+ """Return Supabase REST API headers."""
47
+ return {
48
+ "apikey": service_key,
49
+ "Authorization": "Bearer " + service_key,
50
+ "Content-Type": "application/json",
51
+ "Prefer": "return=minimal",
52
+ }
53
+
54
+
55
+ def fetch_refresh_token(
56
+ provider: str,
57
+ supabase_url: str = "",
58
+ service_key: str = "",
59
+ ) -> Tuple[str, str]:
60
+ """Fetch the first refresh token for *provider* from Supabase.
61
+
62
+ Returns (credential_id, refresh_token).
63
+ """
64
+ if not supabase_url or not service_key:
65
+ env = _read_env_file()
66
+ supabase_url = supabase_url or env["NEXT_PUBLIC_SUPABASE_URL"]
67
+ service_key = service_key or env["SUPABASE_SERVICE_ROLE_KEY"]
68
+ resp = httpx.get(
69
+ supabase_url + "/rest/v" + str(_ONE) + "/" + SUPABASE_OAUTH_TABLE,
70
+ params={"provider": "eq." + provider, "select": "*"},
71
+ headers=_supabase_headers(service_key),
72
+ )
73
+ rows = resp.json()
74
+ if not rows:
75
+ raise RuntimeError(f"No {provider} credentials in Supabase")
76
+ row = rows[_ZERO]
77
+ return row["id"], row["refresh_token"]
78
+
79
+
80
+ def save_refresh_token(
81
+ credential_id: str,
82
+ new_refresh_token: str,
83
+ access_token: str = "",
84
+ supabase_url: str = "",
85
+ service_key: str = "",
86
+ ) -> None:
87
+ """Save a rotated refresh token back to Supabase."""
88
+ if not supabase_url or not service_key:
89
+ env = _read_env_file()
90
+ supabase_url = supabase_url or env["NEXT_PUBLIC_SUPABASE_URL"]
91
+ service_key = service_key or env["SUPABASE_SERVICE_ROLE_KEY"]
92
+ body: dict[str, str] = {"refresh_token": new_refresh_token}
93
+ if access_token:
94
+ body["access_token"] = access_token
95
+ httpx.patch(
96
+ supabase_url + "/rest/v" + str(_ONE) + "/" + SUPABASE_OAUTH_TABLE,
97
+ params={"id": "eq." + credential_id},
98
+ json=body,
99
+ headers=_supabase_headers(service_key),
100
+ )
101
+
102
+
103
+ def exchange_anthropic(
104
+ refresh_token: str,
105
+ ) -> Tuple[str, str]:
106
+ """Exchange Anthropic refresh token. Returns (access, new_refresh)."""
107
+ resp = httpx.post(
108
+ ANTHROPIC_OAUTH_TOKEN_URL,
109
+ data={
110
+ "grant_type": "refresh_token",
111
+ "refresh_token": refresh_token,
112
+ "client_id": ANTHROPIC_OAUTH_CLIENT_ID,
113
+ },
114
+ headers={"Content-Type": _CONTENT_TYPE_FORM},
115
+ )
116
+ resp.raise_for_status()
117
+ data = resp.json()
118
+ return data["access_token"], data.get("refresh_token", "")
119
+
120
+
121
+ def exchange_openai(
122
+ refresh_token: str,
123
+ ) -> Tuple[str, str, str]:
124
+ """Exchange OpenAI refresh token. Returns (access, new_refresh, account_id)."""
125
+ resp = httpx.post(
126
+ OPENAI_OAUTH_TOKEN_URL,
127
+ data={
128
+ "grant_type": "refresh_token",
129
+ "refresh_token": refresh_token,
130
+ "client_id": OPENAI_OAUTH_CLIENT_ID,
131
+ },
132
+ headers={"Content-Type": _CONTENT_TYPE_FORM},
133
+ )
134
+ resp.raise_for_status()
135
+ data = resp.json()
136
+ access = data["access_token"]
137
+ new_rt = data.get("refresh_token", "")
138
+ account_id = _extract_account_id(data.get("id_token", ""))
139
+ return access, new_rt, account_id
140
+
141
+
142
+ def _extract_account_id(id_token: str) -> str:
143
+ """Extract chatgpt_account_id from an OpenAI id_token JWT."""
144
+ if not id_token:
145
+ return ""
146
+ parts = id_token.split(".")
147
+ if len(parts) < _ONE + _ONE:
148
+ return ""
149
+ payload = parts[_ONE]
150
+ # Pad base64
151
+ padding = (_ONE + _ONE + _ONE + _ONE) - len(payload) % (
152
+ _ONE + _ONE + _ONE + _ONE
153
+ )
154
+ if padding < (_ONE + _ONE + _ONE + _ONE):
155
+ payload += "=" * padding
156
+ decoded = json.loads(base64.urlsafe_b64decode(payload))
157
+ claims = decoded.get("https://api.openai.com/auth", {})
158
+ return claims.get("chatgpt_account_id", "")
159
+
160
+
161
+ def get_anthropic_access_token() -> str:
162
+ """Full flow: try all Supabase credentials until one works."""
163
+ env = _read_env_file()
164
+ sb_url = env["NEXT_PUBLIC_SUPABASE_URL"]
165
+ sb_key = env["SUPABASE_SERVICE_ROLE_KEY"]
166
+ resp = httpx.get(
167
+ sb_url + "/rest/v" + str(_ONE) + "/" + SUPABASE_OAUTH_TABLE,
168
+ params={"provider": "eq." + SUPABASE_PROVIDER_ANTHROPIC, "select": "*"},
169
+ headers=_supabase_headers(sb_key),
170
+ )
171
+ rows = resp.json()
172
+ last_err: Exception = RuntimeError("No credentials found")
173
+ for row in rows:
174
+ cred_id, rt = row["id"], row["refresh_token"]
175
+ try:
176
+ access, new_rt = exchange_anthropic(rt)
177
+ if new_rt:
178
+ save_refresh_token(cred_id, new_rt, access, sb_url, sb_key)
179
+ return access
180
+ except Exception as exc:
181
+ last_err = exc
182
+ raise last_err
183
+
184
+
185
+ def get_openai_credentials() -> Tuple[str, str]:
186
+ """Full flow: returns (access_token, account_id)."""
187
+ cred_id, rt = fetch_refresh_token(SUPABASE_PROVIDER_OPENAI)
188
+ access, new_rt, account_id = exchange_openai(rt)
189
+ if new_rt:
190
+ save_refresh_token(cred_id, new_rt, access)
191
+ return access, account_id
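The id_token handling in _extract_account_id hinges on restoring base64url padding before decoding; a self-contained sketch of that step with a made-up payload (not a real OpenAI token):

    import base64
    import json

    claims_in = {"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"}}
    payload = base64.urlsafe_b64encode(
        json.dumps(claims_in).encode()
    ).decode().rstrip("=")  # JWTs strip the trailing padding

    padding = 4 - len(payload) % 4
    if padding < 4:
        payload += "=" * padding
    decoded = json.loads(base64.urlsafe_b64decode(payload))
    print(decoded["https://api.openai.com/auth"]["chatgpt_account_id"])  # acct_123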
train/self_play/opponents.py ADDED
@@ -0,0 +1,142 @@
1
+ """Frozen opponents and opponent pool for self-play training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from typing import Callable, List, Optional
7
+
8
+ from env.models import GameAction, GameObservation
9
+ from train.agent import PromptBuilder, parse_action
10
+ from constant_definitions.train.agent_constants import (
11
+ MAX_ACTION_TOKENS,
12
+ SYSTEM_PROMPT,
13
+ )
14
+ from constant_definitions.var.meta.self_play_constants import (
15
+ SELF_PLAY_POOL_MAX_SIZE,
16
+ )
17
+
18
+ _ZERO = int()
19
+
20
+
21
+ class FrozenOpponent:
22
+ """Wraps a generation function for use as opponent_fn in KantEnvironment.
23
+
24
+ Runs inference with no gradients. Compatible with the
25
+ ``opponent_fn: Callable[[GameObservation], GameAction]`` interface
26
+ that KantEnvironment.reset() accepts.
27
+
28
+ Parameters
29
+ ----------
30
+ generate_fn : callable
31
+ A function ``(prompt: str) -> str`` that produces a completion.
32
+ prompt_builder : PromptBuilder, optional
33
+ Custom prompt builder. Defaults to the standard PromptBuilder.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ generate_fn: Callable[[str], str],
39
+ prompt_builder: Optional[PromptBuilder] = None,
40
+ ) -> None:
41
+ self._generate_fn = generate_fn
42
+ self._builder = prompt_builder or PromptBuilder()
43
+
44
+ def __call__(self, obs: GameObservation) -> GameAction:
45
+ """Select an action given a game observation."""
46
+ prompt = self._builder.build(obs)
47
+ completion = self._generate_fn(prompt)
48
+ action_str = parse_action(completion, obs.available_actions)
49
+ return GameAction(action=action_str)
50
+
51
+ @classmethod
52
+ def from_model(
53
+ cls,
54
+ model: object,
55
+ tokenizer: object,
56
+ max_tokens: int = MAX_ACTION_TOKENS,
57
+ ) -> FrozenOpponent:
58
+ """Create from a HuggingFace model (runs with torch.no_grad)."""
59
+ import torch
60
+
61
+ def _generate(prompt: str) -> str:
62
+ with torch.no_grad():
63
+ inputs = tokenizer(prompt, return_tensors="pt")
64
+ input_len = len(inputs["input_ids"][_ZERO])
65
+ outputs = model.generate(
66
+ **inputs, max_new_tokens=max_tokens,
67
+ )
68
+ return tokenizer.decode(
69
+ outputs[_ZERO][input_len:],
70
+ skip_special_tokens=True,
71
+ )
72
+
73
+ return cls(generate_fn=_generate)
74
+
75
+ @classmethod
76
+ def from_checkpoint(
77
+ cls,
78
+ path: str,
79
+ tokenizer_name: str,
80
+ max_tokens: int = MAX_ACTION_TOKENS,
81
+ ) -> FrozenOpponent:
82
+ """Load a frozen opponent from a saved checkpoint directory."""
83
+ from transformers import AutoModelForCausalLM, AutoTokenizer
84
+
85
+ loaded_model = AutoModelForCausalLM.from_pretrained(path)
86
+ loaded_model.eval()
87
+ loaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
88
+ return cls.from_model(loaded_model, loaded_tokenizer, max_tokens)
89
+
90
+ @classmethod
91
+ def from_api(
92
+ cls,
93
+ api_call_fn: Callable[[str, str], str],
94
+ ) -> FrozenOpponent:
95
+ """Create from an API-based agent (OpenAI, Anthropic, etc.)."""
96
+ return cls(
97
+ generate_fn=lambda prompt: api_call_fn(SYSTEM_PROMPT, prompt),
98
+ )
99
+
100
+
101
+ class OpponentPool:
102
+ """Maintains a pool of past model checkpoints as diverse opponents.
103
+
104
+ Samples uniformly from the pool for opponent diversity.
105
+ Evicts the oldest entry when the pool exceeds ``max_size``.
106
+
107
+ Parameters
108
+ ----------
109
+ max_size : int
110
+ Maximum number of frozen opponents to keep in the pool.
111
+ """
112
+
113
+ def __init__(self, max_size: int = SELF_PLAY_POOL_MAX_SIZE) -> None:
114
+ self._pool: List[FrozenOpponent] = []
115
+ self._max_size = max_size
116
+
117
+ def add(self, opponent: FrozenOpponent) -> None:
118
+ """Add a frozen opponent to the pool, evicting oldest if full."""
119
+ self._pool.append(opponent)
120
+ if len(self._pool) > self._max_size:
121
+ self._pool.pop(_ZERO)
122
+
123
+ def sample(self) -> FrozenOpponent:
124
+ """Return a randomly chosen opponent from the pool.
125
+
126
+ Raises
127
+ ------
128
+ IndexError
129
+ If the pool is empty.
130
+ """
131
+ if not self._pool:
132
+ raise IndexError("Cannot sample from an empty opponent pool.")
133
+ return random.choice(self._pool)
134
+
135
+ def get_opponent_fn(self) -> Callable[[GameObservation], GameAction]:
136
+ """Return a callable that uses a sampled opponent."""
137
+ return self.sample()
138
+
139
+ @property
140
+ def size(self) -> int:
141
+ """Current number of opponents in the pool."""
142
+ return len(self._pool)
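A minimal sketch of the pool mechanics with stub opponents; the lambdas are hypothetical stand-ins for real generate functions:

    from train.self_play.opponents import FrozenOpponent, OpponentPool

    pool = OpponentPool(max_size=2)
    pool.add(FrozenOpponent(generate_fn=lambda prompt: "cooperate"))
    pool.add(FrozenOpponent(generate_fn=lambda prompt: "defect"))
    pool.add(FrozenOpponent(generate_fn=lambda prompt: "cooperate"))  # evicts the oldest

    print(pool.size)                      # 2
    opponent_fn = pool.get_opponent_fn()  # Callable[[GameObservation], GameAction]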
train/self_play/trainer.py ADDED
@@ -0,0 +1,276 @@
1
+ """Self-play GRPO trainer for multi-agent training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import logging
7
+ import random
8
+ from typing import Any, Callable, Dict, List, Optional
9
+
10
+ from env.environment import KantEnvironment
11
+ from env.models import GameAction, GameObservation
12
+ from train.agent import LLMAgent, parse_action
13
+ from train.rewards import episode_reward
14
+ from train.trajectory import TrajectoryCollector, EpisodeTrajectory
15
+ from train.self_play.opponents import FrozenOpponent, OpponentPool
16
+ from train.self_play.config import SelfPlayConfig
17
+ from constant_definitions.train.agent_constants import SYSTEM_PROMPT
18
+ from constant_definitions.train.grpo_constants import GRPO_LOG_EVERY
19
+ from constant_definitions.game_constants import EVAL_ZERO_FLOAT
20
+ from constant_definitions.var.meta.self_play_constants import (
21
+ SELF_PLAY_COOP_WEIGHT_DENOMINATOR,
22
+ SELF_PLAY_COOP_WEIGHT_NUMERATOR,
23
+ SELF_PLAY_EXPLOIT_WEIGHT_DENOMINATOR,
24
+ SELF_PLAY_EXPLOIT_WEIGHT_NUMERATOR,
25
+ SELF_PLAY_FAIRNESS_WEIGHT_DENOMINATOR,
26
+ SELF_PLAY_FAIRNESS_WEIGHT_NUMERATOR,
27
+ SELF_PLAY_PARETO_WEIGHT_DENOMINATOR,
28
+ SELF_PLAY_PARETO_WEIGHT_NUMERATOR,
29
+ SELF_PLAY_ADAPT_WEIGHT_DENOMINATOR,
30
+ SELF_PLAY_ADAPT_WEIGHT_NUMERATOR,
31
+ SELF_PLAY_OPPONENT_LABEL,
32
+ )
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ _ZERO = int()
37
+ _ONE = int(bool(True))
38
+
39
+
40
+ def _self_play_weights() -> Dict[str, float]:
41
+ """Return reward weights tuned for self-play training."""
42
+ return {
43
+ "exploitation_resistance": (
44
+ SELF_PLAY_EXPLOIT_WEIGHT_NUMERATOR
45
+ / SELF_PLAY_EXPLOIT_WEIGHT_DENOMINATOR
46
+ ),
47
+ "cooperation_rate": (
48
+ SELF_PLAY_COOP_WEIGHT_NUMERATOR
49
+ / SELF_PLAY_COOP_WEIGHT_DENOMINATOR
50
+ ),
51
+ "pareto_efficiency": (
52
+ SELF_PLAY_PARETO_WEIGHT_NUMERATOR
53
+ / SELF_PLAY_PARETO_WEIGHT_DENOMINATOR
54
+ ),
55
+ "fairness_index": (
56
+ SELF_PLAY_FAIRNESS_WEIGHT_NUMERATOR
57
+ / SELF_PLAY_FAIRNESS_WEIGHT_DENOMINATOR
58
+ ),
59
+ "adaptability": (
60
+ SELF_PLAY_ADAPT_WEIGHT_NUMERATOR
61
+ / SELF_PLAY_ADAPT_WEIGHT_DENOMINATOR
62
+ ),
63
+ }
64
+
65
+
66
+ class SelfPlayTrainer:
67
+ """GRPO training with self-play opponents.
68
+
69
+ Training loop:
70
+ 1. Collect trajectories: training model vs frozen opponent
71
+ 2. Compute GRPO rewards from episode outcomes
72
+ 3. Update training model via TRL GRPOTrainer
73
+ 4. Periodically refresh frozen opponent from training model
74
+ 5. Add old opponent to pool for diversity
75
+
76
+ Parameters
77
+ ----------
78
+ config : SelfPlayConfig
79
+ Training configuration.
80
+ model : object
81
+ HuggingFace model to train.
82
+ tokenizer : object
83
+ Tokenizer for the model.
84
+ env : KantEnvironment, optional
85
+ Game environment instance.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ config: SelfPlayConfig,
91
+ model: object,
92
+ tokenizer: object,
93
+ env: Optional[KantEnvironment] = None,
94
+ ) -> None:
95
+ self._config = config
96
+ self._model = model
97
+ self._tokenizer = tokenizer
98
+ self._env = env or KantEnvironment()
99
+ self._pool = OpponentPool(max_size=config.pool_max_size)
100
+ self._frozen = FrozenOpponent.from_model(model, tokenizer)
101
+ self._pool.add(self._frozen)
102
+ self._step_count = _ZERO
103
+
104
+ def _model_generate(self, prompt: str) -> str:
105
+ """Generate a completion from the training model."""
106
+ import torch
107
+
108
+ with torch.no_grad():
109
+ inputs = self._tokenizer(prompt, return_tensors="pt")
110
+ input_len = len(inputs["input_ids"][_ZERO])
111
+ outputs = self._model.generate(
112
+ **inputs,
113
+ max_new_tokens=self._config.max_completion_length,
114
+ )
115
+ return self._tokenizer.decode(
116
+ outputs[_ZERO][input_len:],
117
+ skip_special_tokens=True,
118
+ )
119
+
120
+ def collect_trajectories(
121
+ self,
122
+ games: List[str],
123
+ num_episodes: int,
124
+ ) -> List[EpisodeTrajectory]:
125
+ """Collect episodes with current frozen opponent."""
126
+ agent = LLMAgent(generate_fn=self._model_generate)
127
+ collector = TrajectoryCollector(
128
+ env=self._env,
129
+ agent=agent,
130
+ reward_fn=lambda ps, os, cr, tr: episode_reward(
131
+ ps, os, cr, tr, weights=_self_play_weights(),
132
+ ),
133
+ )
134
+ trajectories: List[EpisodeTrajectory] = []
135
+ for _ep in range(num_episodes):
136
+ game = random.choice(games)
137
+ opponent = self._pool.sample()
138
+ traj = collector.collect_episode(
139
+ game=game,
140
+ strategy=SELF_PLAY_OPPONENT_LABEL,
141
+ opponent_fn=opponent,
142
+ )
143
+ trajectories.append(traj)
144
+ return trajectories
145
+
146
+ def make_reward_fn(self) -> Callable[..., List[float]]:
147
+ """Create GRPO reward function using self-play episodes."""
148
+ pool = self._pool
149
+ env = self._env
150
+ weights = _self_play_weights()
151
+
152
+ def reward_fn(
153
+ completions: List[str],
154
+ prompts: List[str],
155
+ **kwargs: Any,
156
+ ) -> List[float]:
157
+ rewards: List[float] = []
158
+ game_keys = kwargs.get(
159
+ "game_key",
160
+ ["prisoners_dilemma"] * len(completions),
161
+ )
162
+ moves_batch = kwargs.get(
163
+ "available_moves",
164
+ [["cooperate", "defect"]] * len(completions),
165
+ )
166
+ for completion, game_key, moves in zip(
167
+ completions, game_keys, moves_batch,
168
+ ):
169
+ action_str = parse_action(completion.strip(), moves)
170
+ opponent = pool.sample()
171
+ obs = env.reset(
172
+ game=game_key, opponent_fn=opponent,
173
+ )
174
+ while not obs.done:
175
+ obs = env.step(GameAction(action=action_str))
176
+ reward = episode_reward(
177
+ obs.player_score,
178
+ obs.opponent_score,
179
+ _compute_coop_rate(obs),
180
+ obs.current_round,
181
+ weights=weights,
182
+ )
183
+ rewards.append(reward)
184
+ return rewards
185
+
186
+ return reward_fn
187
+
188
+ def refresh_opponent(self) -> None:
189
+ """Copy current training model to a new frozen opponent."""
190
+ frozen_model = copy.deepcopy(self._model)
191
+ frozen_model.eval()
192
+ new_opponent = FrozenOpponent.from_model(
193
+ frozen_model, self._tokenizer,
194
+ )
195
+ self._pool.add(new_opponent)
196
+ self._frozen = new_opponent
197
+ logger.info(
198
+ "Refreshed opponent. Pool size: %d", self._pool.size,
199
+ )
200
+
201
+ def train(self, games: List[str]) -> None:
202
+ """Main self-play training loop.
203
+
204
+ Parameters
205
+ ----------
206
+ games : list of str
207
+ Game keys to train on.
208
+ """
209
+ from datasets import Dataset
210
+ from trl import GRPOConfig, GRPOTrainer
211
+ import torch
212
+
213
+ trajectories = self.collect_trajectories(
214
+ games, self._config.warmup_episodes,
215
+ )
216
+ samples = []
217
+ for traj in trajectories:
218
+ for step in traj.steps:
219
+ messages = [
220
+ {"role": "system", "content": SYSTEM_PROMPT},
221
+ {"role": "user", "content": step.prompt},
222
+ ]
223
+ formatted = self._tokenizer.apply_chat_template(
224
+ messages, tokenize=False,
225
+ add_generation_prompt=True,
226
+ )
227
+ samples.append({
228
+ "prompt": formatted,
229
+ "game_key": traj.game,
230
+ "available_moves": ["cooperate", "defect"],
231
+ })
232
+ dataset = Dataset.from_list(samples)
233
+
234
+ reward_fn = self.make_reward_fn()
235
+
236
+ trl_config = GRPOConfig(
237
+ output_dir=self._config.output_dir,
238
+ num_generations=self._config.num_generations,
239
+ max_completion_length=self._config.max_completion_length,
240
+ per_device_train_batch_size=self._config.batch_size,
241
+ learning_rate=self._config.learning_rate,
242
+ max_steps=self._config.max_steps,
243
+ logging_steps=GRPO_LOG_EVERY,
244
+ save_steps=self._config.opponent_update_interval,
245
+ bf16=torch.cuda.is_available(),
246
+ )
247
+
248
+ trainer = GRPOTrainer(
249
+ model=self._model,
250
+ reward_funcs=reward_fn,
251
+ args=trl_config,
252
+ train_dataset=dataset,
253
+ processing_class=self._tokenizer,
254
+ )
255
+
256
+ trainer.train()
257
+ trainer.save_model(self._config.output_dir)
258
+
259
+
260
+ # ---------------------------------------------------------------------------
261
+ # Helpers
262
+ # ---------------------------------------------------------------------------
263
+
264
+ _COOPERATIVE_ACTIONS = frozenset({"cooperate", "stag", "dove"})
265
+
266
+
267
+ def _compute_coop_rate(obs: GameObservation) -> float:
268
+ """Fraction of cooperative moves in an episode."""
269
+ if not obs.history:
270
+ return EVAL_ZERO_FLOAT
271
+ total = len(obs.history)
272
+ count = _ZERO
273
+ for rnd in obs.history:
274
+ if rnd.player_action in _COOPERATIVE_ACTIONS:
275
+ count += _ONE
276
+ return count / total
train/splits.py ADDED
@@ -0,0 +1,77 @@
1
+ """Deterministic stratified train/eval game split."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from typing import Dict, FrozenSet, List, Set, Tuple
7
+
8
+ from common.games_meta.game_tags import GAME_TAGS
9
+ from constant_definitions.batch4.tag_constants import CATEGORIES
10
+ from constant_definitions.game_constants import EVAL_ZERO, EVAL_ONE
11
+ from constant_definitions.train.split_constants import (
12
+ MIN_EVAL_TAG_FRACTION_DENOMINATOR,
13
+ MIN_EVAL_TAG_FRACTION_NUMERATOR,
14
+ SPLIT_SEED,
15
+ TRAIN_FRACTION_DENOMINATOR,
16
+ TRAIN_FRACTION_NUMERATOR,
17
+ )
18
+
19
+ # Domain tags are used for stratification
20
+ _DOMAIN_TAGS: List[str] = CATEGORIES["domain"]
21
+
22
+
23
+ def get_train_eval_split(
24
+ seed: int = SPLIT_SEED,
25
+ ) -> Tuple[FrozenSet[str], FrozenSet[str]]:
26
+ """Return (train_games, eval_games) as frozen sets of game keys.
27
+
28
+ The split is deterministic for a given seed and stratified so that
29
+ every domain tag has at least ``MIN_EVAL_TAG_FRACTION`` representation
30
+ in the eval set.
31
+ """
32
+ all_games = sorted(GAME_TAGS.keys())
33
+ rng = random.Random(seed)
34
+
35
+ # Build domain -> games index
36
+ domain_to_games: Dict[str, List[str]] = {tag: [] for tag in _DOMAIN_TAGS}
37
+ for game_key in all_games:
38
+ tags = GAME_TAGS[game_key]
39
+ for dtag in _DOMAIN_TAGS:
40
+ if dtag in tags:
41
+ domain_to_games[dtag].append(game_key)
42
+
43
+ # Guarantee minimum eval representation per domain
44
+ eval_set: Set[str] = set()
45
+ for dtag in _DOMAIN_TAGS:
46
+ games_with_tag = domain_to_games[dtag]
47
+ if not games_with_tag:
48
+ continue
49
+ min_eval = _min_eval_count(len(games_with_tag))
50
+ already_in_eval = [g for g in games_with_tag if g in eval_set]
51
+ needed = min_eval - len(already_in_eval)
52
+ if needed > EVAL_ZERO:
53
+ candidates = [g for g in games_with_tag if g not in eval_set]
54
+ rng.shuffle(candidates)
55
+ for g in candidates[:needed]:
56
+ eval_set.add(g)
57
+
58
+ # Fill remaining eval slots up to target size
59
+ total = len(all_games)
60
+ target_train = (total * TRAIN_FRACTION_NUMERATOR) // TRAIN_FRACTION_DENOMINATOR
61
+ target_eval = total - target_train
62
+ remaining = [g for g in all_games if g not in eval_set]
63
+ rng.shuffle(remaining)
64
+ slots_to_fill = target_eval - len(eval_set)
65
+ if slots_to_fill > EVAL_ZERO:
66
+ for g in remaining[:slots_to_fill]:
67
+ eval_set.add(g)
68
+
69
+ train_set = frozenset(g for g in all_games if g not in eval_set)
70
+ return train_set, frozenset(eval_set)
71
+
72
+
73
+ def _min_eval_count(tag_total: int) -> int:
74
+ """Minimum number of games with a given tag that must be in eval."""
75
+ _numer = tag_total * MIN_EVAL_TAG_FRACTION_NUMERATOR
76
+ result = (_numer + MIN_EVAL_TAG_FRACTION_DENOMINATOR - EVAL_ONE) // MIN_EVAL_TAG_FRACTION_DENOMINATOR
77
+ return max(result, EVAL_ONE)
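The per-tag floor is a ceiling division; a worked example assuming MIN_EVAL_TAG_FRACTION is 1/10 (purely illustrative, the real fraction lives in split_constants):

    # ceil(tag_total * numer / denom), but never below one game per domain tag.
    def min_eval_count(tag_total: int, numer: int = 1, denom: int = 10) -> int:
        return max((tag_total * numer + denom - 1) // denom, 1)

    print(min_eval_count(23))  # ceil(2.3) -> 3
    print(min_eval_count(4))   # ceil(0.4) -> 1, the per-tag floor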
train/train.py ADDED
@@ -0,0 +1,403 @@
1
+ """KantBench GRPO Training Script.
2
+
3
+ Trains a language model to play 2-player game theory games optimally
4
+ using Group Relative Policy Optimization (GRPO) via TRL.
5
+
6
+ The KantBench environment runs as a remote OpenEnv server (HF Space):
7
+ - Each GRPO completion is a single move
8
+ - The reward function plays a FULL multi-round episode using that move
9
+ as the agent's consistent strategy via the OpenEnv client
10
+ - The composite reward (payoff + cooperation + Pareto efficiency + fairness)
11
+ becomes the GRPO signal
12
+
13
+ Supports the full KantBench game library including:
14
+ - 90+ base 2-player games and 3 N-player games
15
+ - 9 pre-registered meta-games (rule_proposal, rule_signal, gossip)
16
+ - Dynamic variant composition (cheap_talk, exit, binding_commitment,
17
+ constitutional, proposer_responder, noisy_actions, noisy_payoffs)
18
+
19
+ Usage:
20
+ python -m train.train --model Qwen/Qwen2.5-7B-Instruct --max-steps 200
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import logging
27
+ import random
28
+ from typing import Any
29
+
30
+ import torch
31
+ from datasets import Dataset
32
+ from trl import GRPOConfig, GRPOTrainer
33
+ from transformers import AutoTokenizer
34
+
35
+ from common.games import GAMES
36
+ from common.strategies import STRATEGIES as STRATEGY_REGISTRY
37
+ from spaces.kant.client import KantBenchEnv
38
+ from spaces.kant.models import KantBenchAction, KantBenchObservation
39
+ from train.agent import parse_action
40
+ from train.rewards import episode_reward
41
+ from train.splits import get_train_eval_split
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Config
47
+ # ---------------------------------------------------------------------------
48
+
49
+ KANTBENCH_URL = "https://openenv-community-kantbench.hf.space"
50
+
51
+ SYSTEM_PROMPT = (
52
+ "You are playing a game-theory game. Analyse the situation and choose "
53
+ "the best action. Respond with ONLY the action name, nothing else."
54
+ )
55
+
56
+ # Variants that can be dynamically composed on top of base games.
57
+ # These are applied server-side via the variant= reset parameter.
58
+ TRAINABLE_VARIANTS = [
59
+ "cheap_talk",
60
+ "exit",
61
+ "binding_commitment",
62
+ "constitutional",
63
+ "noisy_actions",
64
+ "noisy_payoffs",
65
+ "rule_proposal",
66
+ "rule_signal",
67
+ "gossip",
68
+ ]
69
+
70
+ # Base games suitable for variant composition (2-player matrix games).
71
+ VARIANT_BASE_GAMES = [
72
+ "prisoners_dilemma",
73
+ "stag_hunt",
74
+ "hawk_dove",
75
+ ]
76
+
77
+ # Fraction of dataset samples that use dynamic variant composition.
78
+ VARIANT_FRACTION = 0.3
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # Helpers to bridge KantBenchObservation -> training code
83
+ # ---------------------------------------------------------------------------
84
+
85
+
86
+ def _obs_cooperation_rate(obs: KantBenchObservation) -> float:
87
+ """Compute cooperation rate from a KantBenchObservation's history."""
88
+ if not obs.history:
89
+ return 0.0
90
+ coop_actions = {"cooperate", "stag", "dove", "contribute"}
91
+ coop_count = sum(
92
+ 1 for h in obs.history
93
+ if any(ca in h.get("your_move", "") for ca in coop_actions)
94
+ )
95
+ return coop_count / len(obs.history)
96
+
97
+
98
+ def _build_prompt(obs: KantBenchObservation) -> str:
99
+ """Build a structured prompt from a KantBenchObservation.
100
+
101
+ Mirrors PromptBuilder.build() but works with the OpenEnv client's
102
+ observation format.
103
+ """
104
+ sections: list[str] = []
105
+
106
+ # Game section
107
+ sections.append(
108
+ f"[Game]\n{obs.game_name}\n{obs.game_description}"
109
+ )
110
+
111
+ # History section
112
+ if obs.history:
113
+ history_lines: list[str] = []
114
+ for h in obs.history[-5:]: # Last 5 rounds
115
+ line = (
116
+ f"Round {h.get('round', '?')}"
117
+ f" | You played: {h.get('your_move', '?')}"
118
+ f" | Opponent played: {h.get('opponent_move', '?')}"
119
+ f" | Your payoff: {h.get('your_payoff', '?')}"
120
+ f" | Opp payoff: {h.get('opponent_payoff', '?')}"
121
+ )
122
+ history_lines.append(line)
123
+ sections.append("[History]\n" + "\n".join(history_lines))
124
+
125
+ # Scores section
126
+ sections.append(
127
+ f"[Scores]\nYour score: {obs.cumulative_score}"
128
+ f"\nRound: {obs.round_number} of {obs.max_rounds}"
129
+ )
130
+
131
+ # Available actions
132
+ action_lines = [f"- {a}" for a in obs.available_moves]
133
+ sections.append("[Available Actions]\n" + "\n".join(action_lines))
134
+
135
+ # Instruction
136
+ sections.append(f"[Instruction]\n{SYSTEM_PROMPT}")
137
+
138
+ return "\n\n".join(sections)
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Dataset generation using PromptBuilder
142
+ # ---------------------------------------------------------------------------
143
+
144
+
145
+ def build_dataset(
146
+ base_url: str,
147
+ n_samples: int = 1000,
148
+ games: list[str] | None = None,
149
+ strategies: list[str] | None = None,
150
+ variant_fraction: float = VARIANT_FRACTION,
151
+ ) -> Dataset:
152
+ """Generate diverse game theory prompts for GRPO training.
153
+
154
+ Connects to the KantBench OpenEnv server to generate real observations,
155
+ then builds structured prompts from diverse game states.
156
+
157
+ A fraction of samples use dynamic variant composition (cheap_talk,
158
+ constitutional, gossip, etc.) to train on meta-gaming scenarios.
159
+ """
160
+ game_keys = games or list(GAMES.keys())
161
+ strat_names = strategies or list(STRATEGY_REGISTRY.keys())
162
+ samples = []
163
+
164
+ with KantBenchEnv(base_url=base_url) as env:
165
+ attempts = 0
166
+ while len(samples) < n_samples:
167
+ attempts += 1
168
+
169
+ # Decide whether to use a variant
170
+ use_variant = random.random() < variant_fraction
171
+ if use_variant:
172
+ game_key = random.choice(VARIANT_BASE_GAMES)
173
+ variant = random.choice(TRAINABLE_VARIANTS)
174
+ else:
175
+ game_key = random.choice(game_keys)
176
+ variant = None
177
+
178
+ strategy = random.choice(strat_names)
179
+
180
+ try:
181
+ # Reset env — pass variant for dynamic composition
182
+ reset_kwargs = {"game": game_key, "strategy": strategy}
183
+ if variant:
184
+ reset_kwargs["variant"] = variant
185
+
186
+ result = env.reset(**reset_kwargs)
187
+ obs = result.observation
188
+
189
+ # Play 0..N-1 random rounds to create diverse game states
190
+ max_rounds = obs.max_rounds
191
+ rounds_to_play = random.randint(0, max(max_rounds - 1, 0))
192
+ for _ in range(rounds_to_play):
193
+ move = random.choice(obs.available_moves)
194
+ result = env.step(KantBenchAction(move=move))
195
+ obs = result.observation
196
+ if result.done:
197
+ break
198
+
199
+ if result.done:
200
+ # Episode finished early; reset so the prompt comes from a live game state
201
+ result = env.reset(**reset_kwargs)
202
+ obs = result.observation
203
+
204
+ prompt = _build_prompt(obs)
205
+
206
+ samples.append({
207
+ "prompt": prompt,
208
+ "game_key": game_key,
209
+ "strategy": strategy,
210
+ "variant": variant or "",
211
+ "available_moves": list(obs.available_moves),
212
+ "rounds_remaining": obs.max_rounds - obs.round_number,
213
+ })
214
+ except Exception as exc:  # RuntimeError, ConnectionError, or any other client failure
215
+ logger.debug(
216
+ "Skipping %s/%s (variant=%s): %s",
217
+ game_key, strategy, variant, exc,
218
+ )
219
+ continue
220
+
221
+ return Dataset.from_list(samples)
222
+
223
+
224
+ # ---------------------------------------------------------------------------
225
+ # Reward function — full episode rollout
226
+ # ---------------------------------------------------------------------------
227
+
228
+
229
+ def make_reward_fn(base_url: str):
230
+ """Returns a GRPO reward function that plays full episodes via OpenEnv.
231
+
232
+ For each completion:
233
+ 1. Parse the move from the LLM output
234
+ 2. Reset the KantBench server with the correct game/strategy/variant
235
+ 3. Play the FULL episode using the parsed move as a consistent strategy
236
+ 4. Compute composite reward: payoff + cooperation + Pareto + fairness
237
+ """
238
+ env = KantBenchEnv(base_url=base_url)
239
+ env.connect()
240
+
241
+ def reward_fn(
242
+ completions: list[str],
243
+ prompts: list[str],
244
+ **kwargs: Any,
245
+ ) -> list[float]:
246
+ rewards = []
247
+ game_keys = kwargs.get("game_key", ["prisoners_dilemma"] * len(completions))
248
+ strategies = kwargs.get("strategy", ["tit_for_tat"] * len(completions))
249
+ variants = kwargs.get("variant", [""] * len(completions))
250
+ available_moves_batch = kwargs.get(
251
+ "available_moves", [["cooperate", "defect"]] * len(completions)
252
+ )
253
+
254
+ for completion, game_key, strategy, variant, moves in zip(
255
+ completions, game_keys, strategies, variants, available_moves_batch
256
+ ):
257
+ # Parse move from LLM output
258
+ action_str = parse_action(completion.strip(), moves)
259
+
260
+ try:
261
+ # Play a full episode using this move as a consistent strategy
262
+ reset_kwargs = {"game": game_key, "strategy": strategy}
263
+ if variant:
264
+ reset_kwargs["variant"] = variant
265
+
266
+ result = env.reset(**reset_kwargs)
267
+ while not result.done:
268
+ result = env.step(KantBenchAction(move=action_str))
269
+
270
+ obs = result.observation
271
+
272
+ # Compute cooperation rate from observation history
273
+ coop_rate = _obs_cooperation_rate(obs)
274
+
275
+ # Composite reward from the reward module
276
+ # opponent_score not directly available in KantBenchObservation,
277
+ # approximate from history
278
+ opp_score = sum(
279
+ h.get("opponent_payoff", 0.0) for h in obs.history
280
+ )
281
+ reward = episode_reward(
282
+ player_score=obs.cumulative_score,
283
+ opponent_score=opp_score,
284
+ cooperation_rate=coop_rate,
285
+ total_rounds=obs.round_number,
286
+ )
287
+ rewards.append(reward)
288
+
289
+ except (ValueError, KeyError, RuntimeError, ConnectionError) as exc:
290
+ logger.debug("Reward error for %s/%s: %s", game_key, action_str, exc)
291
+ rewards.append(-1.0)
292
+
293
+ return rewards
294
+
295
+ return reward_fn
296
+
297
+
298
+ # ---------------------------------------------------------------------------
299
+ # Main
300
+ # ---------------------------------------------------------------------------
301
+
302
+
303
+ def parse_args():
304
+ p = argparse.ArgumentParser(description="KantBench GRPO Training")
305
+ p.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct")
306
+ p.add_argument("--output-dir", default="./kantbench-grpo")
307
+ p.add_argument("--env-url", default=KANTBENCH_URL,
308
+ help="KantBench OpenEnv server URL")
309
+ p.add_argument("--episodes", type=int, default=1000, help="Training dataset size")
310
+ p.add_argument("--num-generations", type=int, default=8, help="GRPO group size")
311
+ p.add_argument("--batch-size", type=int, default=4)
312
+ p.add_argument("--grad-accum", type=int, default=4)
313
+ p.add_argument("--lr", type=float, default=5e-6)
314
+ p.add_argument("--max-steps", type=int, default=500)
315
+ p.add_argument("--report-to", default="wandb", help="wandb, tensorboard, or none")
316
+ p.add_argument("--push-to-hub", action="store_true")
317
+ p.add_argument("--hub-model-id", default="jtowarek/kantbench-qwen2.5-7b")
318
+ p.add_argument("--use-train-split", action="store_true",
319
+ help="Use stratified train/eval split (eval games held out)")
320
+ p.add_argument("--variant-fraction", type=float, default=VARIANT_FRACTION,
321
+ help="Fraction of samples using dynamic variant composition")
322
+ return p.parse_args()
323
+
324
+
325
+ def main():
326
+ args = parse_args()
327
+ logging.basicConfig(level=logging.INFO)
328
+
329
+ print(f"Loading model: {args.model}")
330
+ print(f"Output: {args.output_dir}")
331
+ print(f"OpenEnv server: {args.env_url}")
332
+
333
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
334
+ if tokenizer.pad_token is None:
335
+ tokenizer.pad_token = tokenizer.eos_token
336
+
337
+ # Optionally use stratified train/eval split
338
+ train_games = None
339
+ if args.use_train_split:
340
+ train_set, eval_set = get_train_eval_split()
341
+ train_games = sorted(train_set)
342
+ print(f"Using stratified split: {len(train_games)} train, {len(eval_set)} eval games")
343
+
344
+ dataset = build_dataset(
345
+ args.env_url, args.episodes, games=train_games,
346
+ variant_fraction=args.variant_fraction,
347
+ )
348
+ variant_count = sum(1 for v in dataset["variant"] if v)
349
+ print(f"Dataset: {len(dataset)} prompts across {len(GAMES)} games")
350
+ print(f" Variant samples: {variant_count} ({variant_count*100//max(len(dataset),1)}%)")
351
+
352
+ # Format prompts with chat template
353
+ def format_prompt(example):
354
+ messages = [
355
+ {"role": "system", "content": SYSTEM_PROMPT},
356
+ {"role": "user", "content": example["prompt"]},
357
+ ]
358
+ return {
359
+ "prompt": tokenizer.apply_chat_template(
360
+ messages, tokenize=False, add_generation_prompt=True
361
+ )
362
+ }
363
+
364
+ dataset = dataset.map(format_prompt)
365
+
366
+ reward_fn = make_reward_fn(args.env_url)
367
+
368
+ config = GRPOConfig(
369
+ output_dir=args.output_dir,
370
+ num_generations=args.num_generations,
371
+ max_completion_length=32,
372
+ per_device_train_batch_size=args.batch_size,
373
+ gradient_accumulation_steps=args.grad_accum,
374
+ learning_rate=args.lr,
375
+ max_steps=args.max_steps,
376
+ logging_steps=10,
377
+ save_steps=100,
378
+ bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
379
+ fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
380
+ report_to=args.report_to,
381
+ push_to_hub=args.push_to_hub,
382
+ hub_model_id=args.hub_model_id if args.push_to_hub else None,
383
+ )
384
+
385
+ trainer = GRPOTrainer(
386
+ model=args.model,
387
+ reward_funcs=reward_fn,
388
+ args=config,
389
+ train_dataset=dataset,
390
+ processing_class=tokenizer,
391
+ )
392
+
393
+ print("Starting GRPO training...")
394
+ print(" Reward: composite (payoff + cooperation + Pareto + fairness)")
395
+ print(f" Episode: full multi-round rollout via OpenEnv @ {args.env_url}")
396
+ print(f" Variants: {args.variant_fraction*100:.0f}% of samples use dynamic composition")
397
+ trainer.train()
398
+ trainer.save_model(args.output_dir)
399
+ print(f"Done. Model saved to {args.output_dir}")
400
+
401
+
402
+ if __name__ == "__main__":
403
+ main()
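TRL forwards the extra dataset columns (game_key, strategy, variant, available_moves) to the reward function as keyword lists; a toy stand-in showing the call shape, with no server round-trip:

    # Toy reward function with the same signature GRPOTrainer expects.
    def toy_reward_fn(completions, prompts, **kwargs):
        return [
            1.0 if c.strip() in moves else -1.0
            for c, moves in zip(completions, kwargs["available_moves"])
        ]

    print(toy_reward_fn(
        completions=["cooperate", "banana"],
        prompts=["...", "..."],
        game_key=["prisoners_dilemma", "stag_hunt"],
        available_moves=[["cooperate", "defect"], ["stag", "hare"]],
    ))  # [1.0, -1.0]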
train/trajectory.py ADDED
@@ -0,0 +1,206 @@
1
+ """Trajectory collection for training data generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Callable, Dict, List, Optional
7
+
8
+ from env.models import GameAction, GameObservation, RoundResult
9
+ from env.environment import KantEnvironment
10
+ from constant_definitions.game_constants import EVAL_ZERO_FLOAT
11
+
12
+
13
+ @dataclass
14
+ class StepRecord:
15
+ """A single step within an episode trajectory."""
16
+
17
+ prompt: str
18
+ completion: str
19
+ action: str
20
+ reward: float
21
+ player_payoff: float
22
+ opponent_payoff: float
23
+ round_number: int
24
+
25
+
26
+ @dataclass
27
+ class EpisodeTrajectory:
28
+ """Complete trajectory of one episode."""
29
+
30
+ game: str
31
+ strategy: str
32
+ steps: List[StepRecord] = field(default_factory=list)
33
+ episode_reward: float = EVAL_ZERO_FLOAT
34
+ player_score: float = EVAL_ZERO_FLOAT
35
+ opponent_score: float = EVAL_ZERO_FLOAT
36
+ cooperation_rate: float = EVAL_ZERO_FLOAT
37
+ rounds_played: int = int()
38
+ metrics: Dict[str, float] = field(default_factory=dict)
39
+
40
+
41
+ class TrajectoryCollector:
42
+ """Runs episodes and collects trajectories for training.
43
+
44
+ Parameters
45
+ ----------
46
+ env : KantEnvironment
47
+ The game environment instance.
48
+ agent : LLMAgent
49
+ An agent with ``last_prompt`` / ``last_completion`` properties,
50
+ callable with ``(GameObservation) -> GameAction``.
51
+ reward_fn : callable, optional
52
+ Function(player_score, opponent_score, cooperation_rate, rounds) -> float.
53
+ step_reward_fn : callable, optional
54
+ Function(player_payoff, opponent_payoff, payoff_min, payoff_max) -> float.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ env: KantEnvironment,
60
+ agent: Any,
61
+ reward_fn: Optional[Callable[..., float]] = None,
62
+ step_reward_fn: Optional[Callable[..., float]] = None,
63
+ ) -> None:
64
+ self._env = env
65
+ self._agent = agent
66
+ self._reward_fn = reward_fn
67
+ self._step_reward_fn = step_reward_fn
68
+
69
+ def collect_episode(
70
+ self,
71
+ game: str,
72
+ strategy: str = "tit_for_tat",
73
+ opponent_fn: Optional[Callable] = None,
74
+ ) -> EpisodeTrajectory:
75
+ """Run a single episode and return its trajectory."""
76
+ if opponent_fn is not None:
77
+ obs = self._env.reset(game=game, opponent_fn=opponent_fn)
78
+ else:
79
+ obs = self._env.reset(game=game, strategy=strategy)
80
+ steps: List[StepRecord] = []
81
+
82
+ while not obs.done:
83
+ action = self._agent(obs)
84
+
85
+ # Capture prompt/completion from agent
86
+ prompt = getattr(self._agent, "last_prompt", "")
87
+ completion = getattr(self._agent, "last_completion", "")
88
+
89
+ next_obs = self._env.step(action)
90
+
91
+ # Compute step reward
92
+ step_reward = EVAL_ZERO_FLOAT
93
+ if self._step_reward_fn is not None and next_obs.last_round is not None:
94
+ step_reward = self._step_reward_fn(
95
+ next_obs.last_round.player_payoff,
96
+ next_obs.last_round.opponent_payoff,
97
+ EVAL_ZERO_FLOAT,
98
+ EVAL_ZERO_FLOAT,
99
+ )
100
+
101
+ # Record step
102
+ last_rnd = next_obs.last_round
103
+ steps.append(StepRecord(
104
+ prompt=prompt,
105
+ completion=completion,
106
+ action=action.action,
107
+ reward=step_reward,
108
+ player_payoff=(
109
+ last_rnd.player_payoff if last_rnd is not None
110
+ else EVAL_ZERO_FLOAT
111
+ ),
112
+ opponent_payoff=(
113
+ last_rnd.opponent_payoff if last_rnd is not None
114
+ else EVAL_ZERO_FLOAT
115
+ ),
116
+ round_number=next_obs.current_round,
117
+ ))
118
+ obs = next_obs
119
+
120
+ # Compute cooperation rate (reusing tournament logic pattern)
121
+ coop_rate = _compute_cooperation_rate(obs)
122
+
123
+ # Compute episode reward
124
+ ep_reward = EVAL_ZERO_FLOAT
125
+ if self._reward_fn is not None:
126
+ ep_reward = self._reward_fn(
127
+ obs.player_score,
128
+ obs.opponent_score,
129
+ coop_rate,
130
+ obs.current_round,
131
+ )
132
+
133
+ return EpisodeTrajectory(
134
+ game=game,
135
+ strategy=strategy,
136
+ steps=steps,
137
+ episode_reward=ep_reward,
138
+ player_score=obs.player_score,
139
+ opponent_score=obs.opponent_score,
140
+ cooperation_rate=coop_rate,
141
+ rounds_played=obs.current_round,
142
+ )
143
+
144
+ def collect_batch(
145
+ self,
146
+ games: List[str],
147
+ strategies: Optional[List[str]] = None,
148
+ episodes_per_pair: int = int(bool(True)),
149
+ opponent_fn: Optional[Callable] = None,
150
+ ) -> List[EpisodeTrajectory]:
151
+ """Collect trajectories for all (game, strategy) combinations.
152
+
153
+ If *opponent_fn* is provided, self-play mode is used: only
154
+ games are iterated (strategies are ignored).
155
+ """
156
+ trajectories: List[EpisodeTrajectory] = []
157
+ if opponent_fn is not None:
158
+ for game in games:
159
+ for _ep in range(episodes_per_pair):
160
+ traj = self.collect_episode(
161
+ game, opponent_fn=opponent_fn,
162
+ )
163
+ trajectories.append(traj)
164
+ else:
165
+ strats = strategies or ["tit_for_tat"]
166
+ for game in games:
167
+ for strategy in strats:
168
+ for _ep in range(episodes_per_pair):
169
+ traj = self.collect_episode(game, strategy)
170
+ trajectories.append(traj)
171
+ return trajectories
172
+
173
+
174
+ # ---------------------------------------------------------------------------
175
+ # Helpers
176
+ # ---------------------------------------------------------------------------
177
+
178
+ _COOPERATIVE_ACTIONS = frozenset({"cooperate", "stag", "dove"})
179
+ _ECONOMIC_PREFIXES = frozenset({"offer", "invest", "contribute"})
180
+
181
+ _ZERO = int()
182
+ _ONE = int(bool(True))
183
+ _TWO = _ONE + _ONE
184
+
185
+
186
+ def _compute_cooperation_rate(obs: GameObservation) -> float:
187
+ """Fraction of cooperative moves in an episode."""
188
+ if not obs.history:
189
+ return EVAL_ZERO_FLOAT
190
+ total = len(obs.history)
191
+ cooperative_count = _ZERO
192
+ first_action = obs.history[_ZERO].player_action
193
+ prefix = first_action.split("_")[_ZERO]
194
+ is_economic = prefix in _ECONOMIC_PREFIXES
195
+ if is_economic:
196
+ median_idx = len(obs.available_actions) // _TWO
197
+ for rnd in obs.history:
198
+ act = rnd.player_action
199
+ if act in obs.available_actions:
200
+ if obs.available_actions.index(act) >= median_idx:
201
+ cooperative_count += _ONE
202
+ else:
203
+ for rnd in obs.history:
204
+ if rnd.player_action in _COOPERATIVE_ACTIONS:
205
+ cooperative_count += _ONE
206
+ return cooperative_count / total
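A sketch of wiring the collector with a stub agent; the stub satisfies the last_prompt/last_completion contract, and the game and strategy keys are assumed to exist in the repo's registries:

    from env.environment import KantEnvironment
    from env.models import GameAction, GameObservation
    from train.rewards import episode_reward
    from train.trajectory import TrajectoryCollector

    class StubAgent:
        """Always cooperates and records prompt/completion like LLMAgent."""
        last_prompt: str = ""
        last_completion: str = ""

        def __call__(self, obs: GameObservation) -> GameAction:
            self.last_prompt = obs.game_name
            self.last_completion = "cooperate"
            return GameAction(action="cooperate")

    collector = TrajectoryCollector(
        env=KantEnvironment(),
        agent=StubAgent(),
        reward_fn=episode_reward,
    )
    traj = collector.collect_episode(game="prisoners_dilemma", strategy="tit_for_tat")
    print(traj.rounds_played, traj.episode_reward)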