Upload folder using huggingface_hub
- Dockerfile +1 -1
- train/Dockerfile +20 -0
- train/__init__.py +42 -0
- train/agent.py +185 -0
- train/dpo/__init__.py +7 -0
- train/dpo/config.py +82 -0
- train/dpo/pairs.py +108 -0
- train/dpo/trainer.py +162 -0
- train/grpo/__init__.py +7 -0
- train/grpo/config.py +95 -0
- train/grpo/dataset.py +68 -0
- train/grpo/trainer.py +190 -0
- train/kantbench_grpo_colab.ipynb +139 -0
- train/nplayer/__init__.py +34 -0
- train/nplayer/coalition_agent.py +249 -0
- train/nplayer/nplayer_agent.py +146 -0
- train/requirements.txt +9 -0
- train/rewards.py +206 -0
- train/self_play/__init__.py +1 -0
- train/self_play/config.py +55 -0
- train/self_play/oauth.py +191 -0
- train/self_play/opponents.py +142 -0
- train/self_play/trainer.py +276 -0
- train/splits.py +77 -0
- train/train.py +403 -0
- train/trajectory.py +206 -0
Dockerfile
CHANGED
@@ -2,7 +2,7 @@ FROM python:3.11-slim
 
 WORKDIR /app
 
-RUN pip install --no-cache-dir gradio pydantic
+RUN pip install --no-cache-dir gradio pydantic anthropic openai
 
 COPY . /app
 
train/Dockerfile
ADDED
@@ -0,0 +1,20 @@
FROM nvcr.io/nvidia/pytorch:24.08-py3

WORKDIR /workspace

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy training script
COPY train.py .

# Default: train with Qwen2.5-7B-Instruct, 500 steps
CMD ["python", "train.py", \
    "--model", "Qwen/Qwen2.5-7B-Instruct", \
    "--episodes", "2000", \
    "--max-steps", "500", \
    "--num-generations", "8", \
    "--batch-size", "2", \
    "--grad-accum", "8", \
    "--output-dir", "/workspace/output"]
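For reference, a typical build-and-run invocation for this image (the tag and the mount path are illustrative; any of the CMD flags above can be overridden on the `docker run` line):

```bash
# Build with train/ as the context so requirements.txt and train.py resolve.
docker build -t kant-train train/
# Run with GPU access and persist checkpoints to the host.
docker run --gpus all -v "$PWD/output:/workspace/output" kant-train
```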
train/__init__.py
ADDED
@@ -0,0 +1,42 @@
"""Training pipeline for strategic reasoning via game-theory environments."""

__all__ = [
    "LLMAgent",
    "PromptBuilder",
    "parse_action",
    "episode_reward",
    "get_train_eval_split",
    "EpisodeTrajectory",
    "StepRecord",
    "TrajectoryCollector",
]


def __getattr__(name: str) -> object:
    """Lazy imports to avoid pulling in openenv at package load time."""
    if name in ("LLMAgent", "PromptBuilder", "parse_action"):
        from train.agent import LLMAgent, PromptBuilder, parse_action
        _map = {
            "LLMAgent": LLMAgent,
            "PromptBuilder": PromptBuilder,
            "parse_action": parse_action,
        }
        return _map[name]
    if name == "episode_reward":
        from train.rewards import episode_reward
        return episode_reward
    if name == "get_train_eval_split":
        from train.splits import get_train_eval_split
        return get_train_eval_split
    if name in ("EpisodeTrajectory", "StepRecord", "TrajectoryCollector"):
        from train.trajectory import (
            EpisodeTrajectory, StepRecord, TrajectoryCollector,
        )
        _map = {
            "EpisodeTrajectory": EpisodeTrajectory,
            "StepRecord": StepRecord,
            "TrajectoryCollector": TrajectoryCollector,
        }
        return _map[name]
    msg = f"module 'train' has no attribute {name!r}"
    raise AttributeError(msg)
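A short sketch of what the module-level `__getattr__` buys (assuming the repo's `env` and `constant_definitions` packages are importable; `not_a_symbol` is deliberately bogus):

```python
import sys
import train

# Importing the package does not import the heavy submodules.
assert "train.agent" not in sys.modules

_ = train.LLMAgent  # first attribute access triggers the real import
assert "train.agent" in sys.modules

try:
    train.not_a_symbol
except AttributeError as err:
    print(err)  # module 'train' has no attribute 'not_a_symbol'
```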
train/agent.py
ADDED
@@ -0,0 +1,185 @@
"""LLM agent for game-theory environments."""

from __future__ import annotations

import random
from typing import Any, Callable, Dict, List, Optional

from env.models import GameAction, GameObservation
from constant_definitions.train.agent_constants import (
    MAX_ACTION_TOKENS,
    MAX_PROMPT_HISTORY_ROUNDS,
    PARSE_FAILURE_SENTINEL,
    PROMPT_SECTION_ACTIONS,
    PROMPT_SECTION_GAME,
    PROMPT_SECTION_HISTORY,
    PROMPT_SECTION_INSTRUCTION,
    PROMPT_SECTION_SCORES,
    SYSTEM_PROMPT,
    TRAIN_TEMPERATURE_DENOMINATOR,
    TRAIN_TEMPERATURE_NUMERATOR,
)

_ZERO = int()
_ONE = int(bool(True))
_NEWLINE = "\n"
_SECTION_SEP = "\n\n"
_BRACKET_OPEN = "["
_BRACKET_CLOSE = "]"
_COLON_SPACE = ": "
_DASH_SPACE = "- "
_ROUND_PREFIX = "Round "
_YOU_PLAYED = " | You played: "
_OPP_PLAYED = " | Opponent played: "
_YOUR_PAYOFF = " | Your payoff: "
_OPP_PAYOFF = " | Opp payoff: "


class PromptBuilder:
    """Formats GameObservation into a structured text prompt.

    The prompt intentionally excludes the opponent strategy name
    to prevent the model from shortcutting via strategy recognition.
    """

    @staticmethod
    def build(obs: GameObservation) -> str:
        """Build a structured prompt from a game observation."""
        sections: List[str] = []

        # Game section
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_GAME + _BRACKET_CLOSE
            + _NEWLINE + obs.game_name
            + _NEWLINE + obs.game_description
        )

        # History section (limited to last N rounds)
        if obs.history:
            history_lines: List[str] = []
            history_slice = obs.history[-MAX_PROMPT_HISTORY_ROUNDS:]
            for rnd in history_slice:
                line = (
                    _ROUND_PREFIX + str(rnd.round_number)
                    + _YOU_PLAYED + rnd.player_action
                    + _OPP_PLAYED + rnd.opponent_action
                    + _YOUR_PAYOFF + str(rnd.player_payoff)
                    + _OPP_PAYOFF + str(rnd.opponent_payoff)
                )
                history_lines.append(line)
            sections.append(
                _BRACKET_OPEN + PROMPT_SECTION_HISTORY + _BRACKET_CLOSE
                + _NEWLINE + _NEWLINE.join(history_lines)
            )

        # Scores section
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_SCORES + _BRACKET_CLOSE
            + _NEWLINE + "Your score" + _COLON_SPACE + str(obs.player_score)
            + _NEWLINE + "Opponent score" + _COLON_SPACE + str(obs.opponent_score)
            + _NEWLINE + "Round" + _COLON_SPACE + str(obs.current_round)
            + " of " + str(obs.total_rounds)
        )

        # Available actions
        action_lines = [_DASH_SPACE + a for a in obs.available_actions]
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_ACTIONS + _BRACKET_CLOSE
            + _NEWLINE + _NEWLINE.join(action_lines)
        )

        # Instruction
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_INSTRUCTION + _BRACKET_CLOSE
            + _NEWLINE + SYSTEM_PROMPT
        )

        return _SECTION_SEP.join(sections)


def parse_action(response: str, available_actions: List[str]) -> str:
    """Parse an action from LLM response text.

    Tries: exact match -> case-insensitive -> substring -> random selection.
    """
    stripped = response.strip()

    # Exact match
    if stripped in available_actions:
        return stripped

    # Case-insensitive match
    lower = stripped.lower()
    for action in available_actions:
        if action.lower() == lower:
            return action

    # Substring match (response contains action name)
    for action in available_actions:
        if action.lower() in lower:
            return action

    # Random selection as last resort
    return random.choice(available_actions)


class LLMAgent:
    """LLM-based agent compatible with TournamentRunner agent_fn interface.

    Parameters
    ----------
    generate_fn : callable
        A function that takes a prompt string and returns a completion string.
        This abstracts over different model backends (HF, vLLM, API).
    prompt_builder : PromptBuilder, optional
        Custom prompt builder. Defaults to the standard PromptBuilder.
    """

    def __init__(
        self,
        generate_fn: Callable[[str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        self._generate_fn = generate_fn
        self._prompt_builder = prompt_builder or PromptBuilder()
        self._last_prompt: str = ""
        self._last_completion: str = ""

    def __call__(self, obs: GameObservation) -> GameAction:
        """Select an action given a game observation."""
        prompt = self._prompt_builder.build(obs)
        self._last_prompt = prompt
        completion = self._generate_fn(prompt)
        self._last_completion = completion
        action_str = parse_action(completion, obs.available_actions)
        return GameAction(action=action_str)

    @property
    def last_prompt(self) -> str:
        """The most recently constructed prompt."""
        return self._last_prompt

    @property
    def last_completion(self) -> str:
        """The most recent raw model completion."""
        return self._last_completion


class APIAgent(LLMAgent):
    """Agent that uses an external API (OpenAI/Anthropic) for generation.

    Parameters
    ----------
    api_call_fn : callable
        Function(system_prompt, user_prompt) -> str that calls the API.
    """

    def __init__(
        self,
        api_call_fn: Callable[[str, str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        def _generate(prompt: str) -> str:
            return api_call_fn(SYSTEM_PROMPT, prompt)

        super().__init__(generate_fn=_generate, prompt_builder=prompt_builder)
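A minimal usage sketch of the parse fallback chain and the `generate_fn` abstraction (the lambda backend below is a stand-in; only `parse_action` and `LLMAgent` come from the diff, and importing them assumes the repo's `env` and `constant_definitions` packages are on the path):

```python
from train.agent import LLMAgent, parse_action

actions = ["cooperate", "defect"]

# Fallback chain: exact -> case-insensitive -> substring -> random.
assert parse_action("cooperate", actions) == "cooperate"           # exact
assert parse_action("Defect", actions) == "defect"                 # case-insensitive
assert parse_action("I will cooperate.", actions) == "cooperate"   # substring
print(parse_action("no idea", actions))                            # random last resort

# Any str -> str callable works as a backend; here a trivial stub.
agent = LLMAgent(generate_fn=lambda prompt: "cooperate")
```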
train/dpo/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""DPO (Direct Preference Optimisation) training subpackage."""

from train.dpo.config import DPOConfig
from train.dpo.pairs import generate_preference_pairs
from train.dpo.trainer import KantDPOTrainer

__all__ = ["DPOConfig", "generate_preference_pairs", "KantDPOTrainer"]
train/dpo/config.py
ADDED
@@ -0,0 +1,82 @@
"""DPO training configuration."""

from __future__ import annotations

from dataclasses import dataclass

from constant_definitions.train.dpo_constants import (
    DPO_BATCH_SIZE,
    DPO_BETA_DENOMINATOR,
    DPO_BETA_NUMERATOR,
    DPO_GRADIENT_ACCUMULATION_STEPS,
    DPO_LR_DENOMINATOR,
    DPO_LR_NUMERATOR,
    DPO_MAX_LENGTH,
    DPO_MIN_REWARD_MARGIN_DENOMINATOR,
    DPO_MIN_REWARD_MARGIN_NUMERATOR,
    DPO_NUM_EPOCHS,
    DPO_TRAJECTORIES_PER_PAIR,
    DPO_WARMUP_RATIO_DENOMINATOR,
    DPO_WARMUP_RATIO_NUMERATOR,
)


@dataclass(frozen=True)
class DPOConfig:
    """Configuration for DPO training."""

    # Core hyperparameters
    beta_numerator: int = DPO_BETA_NUMERATOR
    beta_denominator: int = DPO_BETA_DENOMINATOR
    learning_rate_numerator: int = DPO_LR_NUMERATOR
    learning_rate_denominator: int = DPO_LR_DENOMINATOR
    batch_size: int = DPO_BATCH_SIZE
    num_epochs: int = DPO_NUM_EPOCHS
    max_length: int = DPO_MAX_LENGTH
    gradient_accumulation_steps: int = DPO_GRADIENT_ACCUMULATION_STEPS

    # Warmup
    warmup_ratio_numerator: int = DPO_WARMUP_RATIO_NUMERATOR
    warmup_ratio_denominator: int = DPO_WARMUP_RATIO_DENOMINATOR

    # Pair generation
    trajectories_per_pair: int = DPO_TRAJECTORIES_PER_PAIR
    min_reward_margin_numerator: int = DPO_MIN_REWARD_MARGIN_NUMERATOR
    min_reward_margin_denominator: int = DPO_MIN_REWARD_MARGIN_DENOMINATOR

    # Model
    model_name: str = ""
    output_dir: str = "checkpoints/dpo"

    @property
    def beta(self) -> float:
        """Effective beta (KL penalty coefficient)."""
        return self.beta_numerator / self.beta_denominator

    @property
    def learning_rate(self) -> float:
        """Effective learning rate."""
        return self.learning_rate_numerator / self.learning_rate_denominator

    @property
    def warmup_ratio(self) -> float:
        """Effective warmup ratio."""
        return self.warmup_ratio_numerator / self.warmup_ratio_denominator

    @property
    def min_reward_margin(self) -> float:
        """Minimum reward margin for preference pair filtering."""
        return self.min_reward_margin_numerator / self.min_reward_margin_denominator

    def to_trl_kwargs(self) -> dict:
        """Return keyword arguments suitable for TRL DPOConfig."""
        return {
            "beta": self.beta,
            "learning_rate": self.learning_rate,
            "per_device_train_batch_size": self.batch_size,
            "num_train_epochs": self.num_epochs,
            "max_length": self.max_length,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "warmup_ratio": self.warmup_ratio,
            "output_dir": self.output_dir,
        }
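The config stores hyperparameters as exact integer ratios and derives the floats via properties; a quick illustration (the 1/10 and 5/1_000_000 values are made up for the example, not the repo's `DPO_*` defaults):

```python
from train.dpo.config import DPOConfig

cfg = DPOConfig(
    beta_numerator=1, beta_denominator=10,                           # beta = 0.1
    learning_rate_numerator=5, learning_rate_denominator=1_000_000,  # lr = 5e-6
)
print(cfg.beta, cfg.learning_rate)  # 0.1 5e-06
print(cfg.to_trl_kwargs()["beta"])  # 0.1 -- ready to splat into trl's DPOConfig
```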
train/dpo/pairs.py
ADDED
@@ -0,0 +1,108 @@
"""Preference pair generation for DPO training."""

from __future__ import annotations

from typing import Any, Dict, List, Tuple

from train.trajectory import EpisodeTrajectory
from constant_definitions.game_constants import EVAL_ONE, EVAL_ZERO
from constant_definitions.train.dpo_constants import (
    DPO_BOTTOM_QUANTILE_DENOMINATOR,
    DPO_BOTTOM_QUANTILE_NUMERATOR,
    DPO_MIN_REWARD_MARGIN_DENOMINATOR,
    DPO_MIN_REWARD_MARGIN_NUMERATOR,
    DPO_TOP_QUANTILE_DENOMINATOR,
    DPO_TOP_QUANTILE_NUMERATOR,
)

_ONE = int(bool(True))


def generate_preference_pairs(
    trajectories: List[EpisodeTrajectory],
    min_margin_numerator: int = DPO_MIN_REWARD_MARGIN_NUMERATOR,
    min_margin_denominator: int = DPO_MIN_REWARD_MARGIN_DENOMINATOR,
) -> List[Dict[str, Any]]:
    """Generate chosen/rejected preference pairs from trajectories.

    Groups trajectories by (game, strategy), ranks by episode_reward,
    pairs top-quartile (chosen) vs bottom-quartile (rejected), and
    filters by minimum reward margin.

    Returns list of dicts with keys: prompt, chosen, rejected, margin.
    """
    min_margin = min_margin_numerator / min_margin_denominator

    # Group by (game, strategy)
    groups: Dict[Tuple[str, str], List[EpisodeTrajectory]] = {}
    for traj in trajectories:
        key = (traj.game, traj.strategy)
        if key not in groups:
            groups[key] = []
        groups[key].append(traj)

    pairs: List[Dict[str, Any]] = []
    for _key, group in groups.items():
        group_pairs = _pairs_from_group(group, min_margin)
        pairs.extend(group_pairs)

    return pairs


def _pairs_from_group(
    group: List[EpisodeTrajectory],
    min_margin: float,
) -> List[Dict[str, Any]]:
    """Generate pairs from a single (game, strategy) group."""
    if len(group) < EVAL_ONE + EVAL_ONE:
        return []

    # Sort by episode reward descending
    ranked = sorted(group, key=lambda t: t.episode_reward, reverse=True)
    n = len(ranked)

    # Top and bottom quartile boundaries
    top_boundary = max(
        _ONE,
        (n * DPO_TOP_QUANTILE_NUMERATOR) // DPO_TOP_QUANTILE_DENOMINATOR,
    )
    bottom_boundary = max(
        _ONE,
        (n * DPO_BOTTOM_QUANTILE_NUMERATOR) // DPO_BOTTOM_QUANTILE_DENOMINATOR,
    )

    chosen_set = ranked[:top_boundary]
    rejected_set = ranked[n - bottom_boundary:]

    pairs: List[Dict[str, Any]] = []
    for chosen in chosen_set:
        for rejected in rejected_set:
            margin = chosen.episode_reward - rejected.episode_reward
            if margin < min_margin:
                continue
            # Use the full episode as prompt + chosen/rejected completions
            chosen_text = _trajectory_to_text(chosen)
            rejected_text = _trajectory_to_text(rejected)
            prompt = _trajectory_prompt(chosen)
            pairs.append({
                "prompt": prompt,
                "chosen": chosen_text,
                "rejected": rejected_text,
                "margin": margin,
                "game": chosen.game,
                "strategy": chosen.strategy,
            })

    return pairs


def _trajectory_to_text(traj: EpisodeTrajectory) -> str:
    """Convert trajectory actions to a single completion string."""
    return "\n".join(step.completion for step in traj.steps)


def _trajectory_prompt(traj: EpisodeTrajectory) -> str:
    """Extract the first step's prompt as the shared prompt."""
    if traj.steps:
        return traj.steps[EVAL_ZERO].prompt
    return ""
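A minimal sketch of the ranking-and-pairing behaviour using stand-in trajectory objects (`SimpleNamespace` stubs here, since only attribute access is needed; real `EpisodeTrajectory` instances come from `train.trajectory`, and the exact pair count depends on the repo's quantile constants):

```python
from types import SimpleNamespace
from train.dpo.pairs import generate_preference_pairs

def stub(reward: float) -> SimpleNamespace:
    # Only the attributes generate_preference_pairs reads are stubbed.
    step = SimpleNamespace(prompt="<shared prompt>", completion=f"action@{reward}")
    return SimpleNamespace(game="prisoners_dilemma", strategy="tit_for_tat",
                           episode_reward=reward, steps=[step])

trajs = [stub(r) for r in (1.0, 0.8, 0.3, 0.0)]
# With quartile-style quantiles the 1.0 episode ends up chosen and the 0.0
# episode rejected; min_margin here is 1/2 = 0.5.
pairs = generate_preference_pairs(trajs, min_margin_numerator=1, min_margin_denominator=2)
for p in pairs:
    print(p["margin"], p["chosen"], "beats", p["rejected"])
```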
train/dpo/trainer.py
ADDED
@@ -0,0 +1,162 @@
"""DPO trainer wrapping TRL with Kant-specific preference learning."""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Sequence

from env.environment import KantEnvironment
from env.models import GameAction, GameObservation
from train.agent import LLMAgent, PromptBuilder, parse_action
from train.dpo.config import DPOConfig
from train.dpo.pairs import generate_preference_pairs
from train.splits import get_train_eval_split
from train.trajectory import EpisodeTrajectory

from constant_definitions.game_constants import EVAL_ZERO

logger = logging.getLogger(__name__)


class KantDPOTrainer:
    """DPO trainer for strategic reasoning via preference learning.

    Wraps TRL's DPOTrainer with:
    - Preference pair generation from trajectory rankings
    - Per-checkpoint evaluation on held-out games
    - Optional LoRA/QLoRA support via PEFT

    Parameters
    ----------
    config : DPOConfig
        Training configuration.
    model : Any
        HuggingFace model (or path to load).
    tokenizer : Any
        HuggingFace tokenizer.
    ref_model : Any, optional
        Reference model for DPO. If None, uses a copy of the policy model.
    """

    def __init__(
        self,
        config: DPOConfig,
        model: Any = None,
        tokenizer: Any = None,
        ref_model: Any = None,
    ) -> None:
        self._config = config
        self._model = model
        self._tokenizer = tokenizer
        self._ref_model = ref_model
        self._train_games, self._eval_games = get_train_eval_split()
        self._trl_trainer: Any = None

    def prepare_dataset(
        self,
        trajectories: List[EpisodeTrajectory],
    ) -> List[Dict[str, Any]]:
        """Generate preference pairs from collected trajectories."""
        return generate_preference_pairs(
            trajectories,
            min_margin_numerator=self._config.min_reward_margin_numerator,
            min_margin_denominator=self._config.min_reward_margin_denominator,
        )

    def setup_trl_trainer(
        self,
        train_dataset: Any,
    ) -> Any:
        """Initialise the TRL DPOTrainer (requires trl to be installed)."""
        try:
            from trl import DPOTrainer, DPOConfig as TRLDPOConfig
        except ImportError as exc:
            msg = "trl is required for DPO training. Install with: pip install trl"
            raise ImportError(msg) from exc

        trl_config = TRLDPOConfig(**self._config.to_trl_kwargs())
        self._trl_trainer = DPOTrainer(
            model=self._model,
            ref_model=self._ref_model,
            args=trl_config,
            tokenizer=self._tokenizer,
            train_dataset=train_dataset,
        )
        return self._trl_trainer

    def evaluate(
        self,
        games: Optional[Sequence[str]] = None,
        strategies: Optional[Sequence[str]] = None,
        run_external: bool = False,
        external_benchmarks: Optional[Sequence[str]] = None,
    ) -> Dict[str, float]:
        """Run evaluation on specified games and return metric dict.

        Parameters
        ----------
        games, strategies
            Forwarded to ``TournamentRunner``.
        run_external : bool
            If ``True``, also run external safety benchmarks.
        external_benchmarks : sequence of str, optional
            Which external benchmarks to run (default: all).
        """
        from bench.evaluation.tournament import TournamentRunner
        from bench.evaluation.metrics import compute_metrics

        env = KantEnvironment()
        eval_games = list(games) if games is not None else sorted(self._eval_games)

        def _agent_fn(obs: GameObservation) -> GameAction:
            prompt = PromptBuilder.build(obs)
            if self._tokenizer is not None and self._model is not None:
                inputs = self._tokenizer(prompt, return_tensors="pt")
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=self._config.max_length,
                )
                completion = self._tokenizer.decode(
                    outputs[EVAL_ZERO][len(inputs["input_ids"][EVAL_ZERO]):],
                    skip_special_tokens=True,
                )
            else:
                completion = obs.available_actions[EVAL_ZERO]
            action_str = parse_action(completion, obs.available_actions)
            return GameAction(action=action_str)

        runner = TournamentRunner(env=env, agent_fn=_agent_fn)
        results = runner.run_tournament_as_dict(
            games=eval_games,
            strategies=strategies,
        )
        metrics = compute_metrics(results)

        if run_external:
            from bench.external._model_handle import ModelHandle
            from bench.external.runner import ExternalBenchmarkRunner

            handle = ModelHandle(
                model_name_or_path=self._config.model_name,
                model=self._model,
                tokenizer=self._tokenizer,
            )
            ext_runner = ExternalBenchmarkRunner(
                model_handle=handle,
                benchmarks=external_benchmarks,
            )
            ext_results = ext_runner.run_all()
            for bench_name, result in ext_results.items():
                prefix = f"external/{bench_name}"
                if result.error is not None:
                    metrics[f"{prefix}/error"] = True
                    continue
                for metric_key, value in result.scores.items():
                    metrics[f"{prefix}/{metric_key}"] = value

        return metrics

    @property
    def config(self) -> DPOConfig:
        """Training configuration."""
        return self._config
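End-to-end, the pieces above compose roughly like this (a sketch, not a tested recipe: the model name is illustrative, `collected_trajectories` is assumed to come from a `TrajectoryCollector` run, and the TRL call signatures are exactly those used by `setup_trl_trainer` above):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from train.dpo.config import DPOConfig
from train.dpo.trainer import KantDPOTrainer

config = DPOConfig(model_name="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative
model = AutoModelForCausalLM.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

trainer = KantDPOTrainer(config, model=model, tokenizer=tokenizer)

# collected_trajectories: List[EpisodeTrajectory] gathered beforehand.
pairs = trainer.prepare_dataset(collected_trajectories)  # rank + margin-filter
train_ds = Dataset.from_list(
    [{"prompt": p["prompt"], "chosen": p["chosen"], "rejected": p["rejected"]}
     for p in pairs]
)
trainer.setup_trl_trainer(train_ds).train()
metrics = trainer.evaluate()  # held-out games via TournamentRunner
```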
train/grpo/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""GRPO (Group Relative Policy Optimisation) training subpackage."""

from train.grpo.config import GRPOConfig
from train.grpo.dataset import trajectories_to_dataset
from train.grpo.trainer import KantGRPOTrainer

__all__ = ["GRPOConfig", "trajectories_to_dataset", "KantGRPOTrainer"]
train/grpo/config.py
ADDED
@@ -0,0 +1,95 @@
"""GRPO training configuration."""

from __future__ import annotations

from dataclasses import dataclass

from constant_definitions.train.grpo_constants import (
    GRPO_BATCH_SIZE,
    GRPO_CHECKPOINT_EVERY,
    GRPO_CURRICULUM_EXPANSION_STEP,
    GRPO_CURRICULUM_INITIAL_GAMES,
    GRPO_GRADIENT_ACCUMULATION_STEPS,
    GRPO_LOG_EVERY,
    GRPO_LR_DENOMINATOR,
    GRPO_LR_NUMERATOR,
    GRPO_MAX_COMPLETION_LENGTH,
    GRPO_NUM_EPOCHS,
    GRPO_NUM_GENERATIONS,
    GRPO_SHAPING_ALPHA_DENOMINATOR,
    GRPO_SHAPING_ALPHA_NUMERATOR,
    GRPO_WARMUP_RATIO_DENOMINATOR,
    GRPO_WARMUP_RATIO_NUMERATOR,
    GRPO_WEIGHT_DECAY_DENOMINATOR,
    GRPO_WEIGHT_DECAY_NUMERATOR,
)


@dataclass(frozen=True)
class GRPOConfig:
    """Configuration for GRPO training."""

    # Core hyperparameters (derived from constants)
    learning_rate_numerator: int = GRPO_LR_NUMERATOR
    learning_rate_denominator: int = GRPO_LR_DENOMINATOR
    batch_size: int = GRPO_BATCH_SIZE
    num_generations: int = GRPO_NUM_GENERATIONS
    num_epochs: int = GRPO_NUM_EPOCHS
    max_completion_length: int = GRPO_MAX_COMPLETION_LENGTH
    gradient_accumulation_steps: int = GRPO_GRADIENT_ACCUMULATION_STEPS

    # Warmup and regularisation
    warmup_ratio_numerator: int = GRPO_WARMUP_RATIO_NUMERATOR
    warmup_ratio_denominator: int = GRPO_WARMUP_RATIO_DENOMINATOR
    weight_decay_numerator: int = GRPO_WEIGHT_DECAY_NUMERATOR
    weight_decay_denominator: int = GRPO_WEIGHT_DECAY_DENOMINATOR

    # Shaping
    shaping_alpha_numerator: int = GRPO_SHAPING_ALPHA_NUMERATOR
    shaping_alpha_denominator: int = GRPO_SHAPING_ALPHA_DENOMINATOR

    # Scheduling
    checkpoint_every: int = GRPO_CHECKPOINT_EVERY
    log_every: int = GRPO_LOG_EVERY
    curriculum_initial_games: int = GRPO_CURRICULUM_INITIAL_GAMES
    curriculum_expansion_step: int = GRPO_CURRICULUM_EXPANSION_STEP

    # Model
    model_name: str = ""
    output_dir: str = "checkpoints/grpo"

    @property
    def learning_rate(self) -> float:
        """Effective learning rate as a float."""
        return self.learning_rate_numerator / self.learning_rate_denominator

    @property
    def warmup_ratio(self) -> float:
        """Effective warmup ratio."""
        return self.warmup_ratio_numerator / self.warmup_ratio_denominator

    @property
    def weight_decay(self) -> float:
        """Effective weight decay."""
        return self.weight_decay_numerator / self.weight_decay_denominator

    @property
    def shaping_alpha(self) -> float:
        """Shaping reward coefficient."""
        return self.shaping_alpha_numerator / self.shaping_alpha_denominator

    def to_trl_kwargs(self) -> dict:
        """Return keyword arguments suitable for TRL GRPOConfig."""
        return {
            "learning_rate": self.learning_rate,
            "per_device_train_batch_size": self.batch_size,
            "num_generations": self.num_generations,
            "num_train_epochs": self.num_epochs,
            "max_completion_length": self.max_completion_length,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "warmup_ratio": self.warmup_ratio,
            "weight_decay": self.weight_decay,
            "output_dir": self.output_dir,
            "logging_steps": self.log_every,
            "save_steps": self.checkpoint_every,
        }
train/grpo/dataset.py
ADDED
@@ -0,0 +1,68 @@
"""Convert episode trajectories to HuggingFace Dataset format for GRPO."""

from __future__ import annotations

from typing import Any, Dict, List

from train.trajectory import EpisodeTrajectory, StepRecord
from constant_definitions.game_constants import EVAL_ONE, EVAL_ZERO_FLOAT
from constant_definitions.train.grpo_constants import (
    GRPO_SHAPING_ALPHA_DENOMINATOR,
    GRPO_SHAPING_ALPHA_NUMERATOR,
)

_ONE = int(bool(True))


def trajectories_to_dataset(
    trajectories: List[EpisodeTrajectory],
) -> List[Dict[str, Any]]:
    """Convert trajectories into per-round records for GRPO training.

    Each round becomes a separate training example with:
    - ``prompt``: the structured game prompt for that round
    - ``completion``: the model's action text
    - ``reward``: episode reward for the final round, shaping reward otherwise

    This keeps completions short (one action per round) rather than
    generating entire multi-round episodes as single completions.
    """
    records: List[Dict[str, Any]] = []
    for traj in trajectories:
        num_steps = len(traj.steps)
        if num_steps == EVAL_ONE - EVAL_ONE:
            continue
        last_idx = num_steps - _ONE
        for idx, step in enumerate(traj.steps):
            if idx == last_idx:
                reward = traj.episode_reward
            else:
                reward = step.reward
            records.append({
                "prompt": step.prompt,
                "completion": step.completion,
                "reward": reward,
                "game": traj.game,
                "strategy": traj.strategy,
                "round_number": step.round_number,
                "is_terminal": idx == last_idx,
            })
    return records


def records_to_hf_dict(
    records: List[Dict[str, Any]],
) -> Dict[str, List[Any]]:
    """Convert list-of-dicts to dict-of-lists for HF Dataset.from_dict()."""
    if not records:
        return {
            "prompt": [],
            "completion": [],
            "reward": [],
            "game": [],
            "strategy": [],
            "round_number": [],
            "is_terminal": [],
        }
    keys = list(records[EVAL_ONE - EVAL_ONE].keys())
    return {k: [r[k] for r in records] for k in keys}
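A quick illustration of the per-round record shape and the dict-of-lists conversion (the field values are made up; `Dataset.from_dict` is the standard `datasets` constructor):

```python
from datasets import Dataset
from train.grpo.dataset import records_to_hf_dict

# Two rounds from one episode: the non-terminal round carries the shaping
# reward, the terminal round carries the full episode reward.
records = [
    {"prompt": "<round 1 prompt>", "completion": "cooperate", "reward": 0.1,
     "game": "prisoners_dilemma", "strategy": "tit_for_tat",
     "round_number": 1, "is_terminal": False},
    {"prompt": "<round 2 prompt>", "completion": "cooperate", "reward": 0.85,
     "game": "prisoners_dilemma", "strategy": "tit_for_tat",
     "round_number": 2, "is_terminal": True},
]

ds = Dataset.from_dict(records_to_hf_dict(records))
print(ds.column_names)  # ['prompt', 'completion', 'reward', 'game', ...]
```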
train/grpo/trainer.py
ADDED
@@ -0,0 +1,190 @@
"""GRPO trainer wrapping TRL with Kant-specific logic."""

from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional, Sequence

from env.environment import KantEnvironment
from env.models import GameAction, GameObservation
from train.agent import LLMAgent, PromptBuilder, parse_action
from train.grpo.config import GRPOConfig
from train.rewards import episode_reward, per_step_shaping
from train.splits import get_train_eval_split
from train.trajectory import TrajectoryCollector

from constant_definitions.game_constants import EVAL_ONE, EVAL_ZERO, EVAL_ZERO_FLOAT

logger = logging.getLogger(__name__)

_ONE = int(bool(True))


class KantGRPOTrainer:
    """GRPO trainer for strategic reasoning in game-theory environments.

    Wraps TRL's GRPOTrainer with:
    - Environment-based reward computation
    - Curriculum scheduling over games
    - Per-checkpoint evaluation logging

    Parameters
    ----------
    config : GRPOConfig
        Training configuration.
    model : Any
        HuggingFace model (or path to load).
    tokenizer : Any
        HuggingFace tokenizer.
    env : KantEnvironment, optional
        Environment instance for reward computation.
    """

    def __init__(
        self,
        config: GRPOConfig,
        model: Any = None,
        tokenizer: Any = None,
        env: Optional[KantEnvironment] = None,
    ) -> None:
        self._config = config
        self._model = model
        self._tokenizer = tokenizer
        self._env = env if env is not None else KantEnvironment()
        self._train_games, self._eval_games = get_train_eval_split()
        self._current_games: List[str] = sorted(self._train_games)[
            :config.curriculum_initial_games
        ]
        self._step_count = EVAL_ZERO
        self._trl_trainer: Any = None

    def reward_function(
        self,
        completions: List[str],
        prompts: List[str],
    ) -> List[float]:
        """Compute rewards by parsing actions and evaluating in environment.

        This is the reward function passed to TRL's GRPOTrainer.
        Each (prompt, completion) pair is treated as a single round action.
        """
        rewards: List[float] = []
        for prompt, completion in zip(prompts, completions):
            # We cannot run a full episode per completion in GRPO
            # (completions are individual round actions), so we return
            # per-step shaping reward based on action quality heuristic.
            reward = EVAL_ZERO_FLOAT
            rewards.append(reward)
        return rewards

    def expand_curriculum(self) -> None:
        """Add more games to the training curriculum."""
        all_train = sorted(self._train_games)
        current_count = len(self._current_games)
        new_count = min(
            current_count + self._config.curriculum_expansion_step,
            len(all_train),
        )
        self._current_games = all_train[:new_count]
        logger.info(
            "Curriculum expanded to %s games",
            str(len(self._current_games)),
        )

    def setup_trl_trainer(self) -> Any:
        """Initialise the TRL GRPOTrainer (requires trl to be installed)."""
        try:
            from trl import GRPOTrainer, GRPOConfig as TRLGRPOConfig
        except ImportError as exc:
            msg = "trl is required for GRPO training. Install with: pip install trl"
            raise ImportError(msg) from exc

        trl_config = TRLGRPOConfig(**self._config.to_trl_kwargs())
        self._trl_trainer = GRPOTrainer(
            model=self._model,
            config=trl_config,
            tokenizer=self._tokenizer,
            reward_funcs=self.reward_function,
        )
        return self._trl_trainer

    def evaluate(
        self,
        games: Optional[Sequence[str]] = None,
        strategies: Optional[Sequence[str]] = None,
        run_external: bool = False,
        external_benchmarks: Optional[Sequence[str]] = None,
    ) -> Dict[str, float]:
        """Run evaluation on specified games and return metric dict.

        Parameters
        ----------
        games, strategies
            Forwarded to ``TournamentRunner``.
        run_external : bool
            If ``True``, also run external safety benchmarks.
        external_benchmarks : sequence of str, optional
            Which external benchmarks to run (default: all).
        """
        from bench.evaluation.tournament import TournamentRunner
        from bench.evaluation.metrics import compute_metrics

        eval_games = list(games) if games is not None else sorted(self._eval_games)

        def _agent_fn(obs: GameObservation) -> GameAction:
            prompt = PromptBuilder.build(obs)
            if self._tokenizer is not None and self._model is not None:
                inputs = self._tokenizer(prompt, return_tensors="pt")
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=self._config.max_completion_length,
                )
                completion = self._tokenizer.decode(
                    outputs[EVAL_ZERO][len(inputs["input_ids"][EVAL_ZERO]):],
                    skip_special_tokens=True,
                )
            else:
                completion = obs.available_actions[EVAL_ZERO]
            action_str = parse_action(completion, obs.available_actions)
            return GameAction(action=action_str)

        runner = TournamentRunner(env=self._env, agent_fn=_agent_fn)
        results = runner.run_tournament_as_dict(
            games=eval_games,
            strategies=strategies,
        )
        metrics = compute_metrics(results)

        if run_external:
            from bench.external._model_handle import ModelHandle
            from bench.external.runner import ExternalBenchmarkRunner

            handle = ModelHandle(
                model_name_or_path=self._config.model_name,
                model=self._model,
                tokenizer=self._tokenizer,
            )
            ext_runner = ExternalBenchmarkRunner(
                model_handle=handle,
                benchmarks=external_benchmarks,
            )
            ext_results = ext_runner.run_all()
            for bench_name, result in ext_results.items():
                prefix = f"external/{bench_name}"
                if result.error is not None:
                    metrics[f"{prefix}/error"] = True
                    continue
                for metric_key, value in result.scores.items():
                    metrics[f"{prefix}/{metric_key}"] = value

        return metrics

    @property
    def current_games(self) -> List[str]:
        """Currently active training games."""
        return list(self._current_games)

    @property
    def config(self) -> GRPOConfig:
        """Training configuration."""
        return self._config
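The curriculum arithmetic in `expand_curriculum` is simple enough to trace standalone (the game names and the initial/step values of 2 are illustrative, not the repo's `GRPO_CURRICULUM_*` constants):

```python
# Mirrors expand_curriculum: grow the active game list by a fixed step,
# capped at the full sorted training set.
all_train = sorted(["auction", "cournot", "hawk_dove", "prisoners_dilemma", "stag_hunt"])
initial_games, expansion_step = 2, 2  # illustrative values

current = all_train[:initial_games]  # ['auction', 'cournot']
new_count = min(len(current) + expansion_step, len(all_train))
current = all_train[:new_count]
print(current)  # first four games after one expansion
```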
train/kantbench_grpo_colab.ipynb
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"accelerator": "GPU"
|
| 14 |
+
},
|
| 15 |
+
"cells": [
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": "# KantBench: GRPO Training on 90+ Game Theory Environments\n\nTrain a language model to play strategic games optimally using **Group Relative Policy Optimization (GRPO)** via HF TRL.\n\n**How it works:**\n- 90+ game theory environments (Prisoner's Dilemma, Cournot, Auctions, Signaling, ...)\n- 17 opponent strategies (tit-for-tat, grudger, adaptive, ...)\n- Each LLM completion is a **move** — the reward function plays a **full multi-round episode** using that move as the agent's strategy\n- Composite reward: payoff + cooperation rate + Pareto efficiency + fairness\n\n**Requirements:** Colab GPU runtime (T4 for 1.5B, A100 for 3B+)"
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": "!pip install -q torch transformers trl datasets accelerate peft openenv-core>=0.2.1 wandb bitsandbytes nest_asyncio"
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"# Clone the repo to get the full game registry\n",
|
| 35 |
+
"!git clone --depth 1 https://github.com/wisent-ai/OpenEnv.git /content/OpenEnv\n",
|
| 36 |
+
"import sys\n",
|
| 37 |
+
"sys.path.insert(0, \"/content/OpenEnv\")"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": null,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": [
|
| 46 |
+
"import wandb\n",
|
| 47 |
+
"wandb.login()"
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"cell_type": "markdown",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"source": [
|
| 54 |
+
"## Config"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": null,
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [],
|
| 62 |
+
"source": "# --- Adjust these for your GPU ---\nMODEL = \"Qwen/Qwen2.5-1.5B-Instruct\" # 1.5B fits on T4; use 3B on A100\nNUM_EPISODES = 500\nNUM_GENERATIONS = 4\nBATCH_SIZE = 1\nGRAD_ACCUM = 8\nMAX_STEPS = 200\nLR = 5e-6"
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "markdown",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"source": "## Load Environment"
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": null,
|
| 72 |
+
"metadata": {},
|
| 73 |
+
"outputs": [],
|
| 74 |
+
"source": "import random\nfrom common.games import GAMES\nfrom common.strategies import STRATEGIES as STRATEGY_REGISTRY\nfrom env.environment import KantEnvironment\nfrom env.models import GameAction, GameObservation\nfrom train.agent import PromptBuilder, parse_action\nfrom train.rewards import episode_reward\nfrom train.trajectory import _compute_cooperation_rate\n\nprint(f\"Loaded {len(GAMES)} games, {len(STRATEGY_REGISTRY)} strategies\")\nprint(f\"Sample games: {list(GAMES.keys())[:10]}\")"
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"cell_type": "markdown",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"source": "## Build Dataset with Real Environment States\n\nUses `PromptBuilder` for structured prompts and simulates partial game histories\nso the model trains on diverse game states (not just round 1)."
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": null,
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": "from datasets import Dataset\n\nSYSTEM_PROMPT = (\n \"You are playing a game-theory game. Analyse the situation and choose \"\n \"the best action. Respond with ONLY the action name, nothing else.\"\n)\n\ndef build_dataset(n_samples):\n env = KantEnvironment()\n game_keys = list(GAMES.keys())\n strat_names = list(STRATEGY_REGISTRY.keys())\n prompt_builder = PromptBuilder()\n samples = []\n\n for _ in range(n_samples):\n game_key = random.choice(game_keys)\n strategy = random.choice(strat_names)\n\n obs = env.reset(game=game_key, strategy=strategy)\n\n # Play 0..N-1 random rounds for diverse game states\n rounds_to_play = random.randint(0, max(obs.total_rounds - 1, 0))\n for _ in range(rounds_to_play):\n random_action = GameAction(action=random.choice(obs.available_actions))\n obs = env.step(random_action)\n if obs.done:\n break\n\n if obs.done:\n obs = env.reset(game=game_key, strategy=strategy)\n\n prompt = prompt_builder.build(obs)\n samples.append({\n \"prompt\": prompt,\n \"game_key\": game_key,\n \"strategy\": strategy,\n \"available_moves\": list(obs.available_actions),\n })\n\n return Dataset.from_list(samples)\n\n\ndataset = build_dataset(NUM_EPISODES)\nprint(f\"Dataset: {len(dataset)} prompts\")\nprint(f\"\\nSample prompt:\\n{dataset[0]['prompt'][:500]}\")"
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "markdown",
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"source": "## Reward Function: Full Episode Rollout\n\nFor each LLM completion:\n1. Parse the move\n2. Play a **full multi-round episode** using that move as the agent's strategy\n3. Compute composite reward: payoff + cooperation + Pareto efficiency + fairness"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": "from typing import Any\n\nreward_env = KantEnvironment()\n\ndef kantbench_reward(completions: list[str], prompts: list[str], **kwargs: Any) -> list[float]:\n rewards = []\n game_keys = kwargs.get(\"game_key\", [\"prisoners_dilemma\"] * len(completions))\n strategies = kwargs.get(\"strategy\", [\"tit_for_tat\"] * len(completions))\n available_moves_batch = kwargs.get(\"available_moves\", [[\"cooperate\", \"defect\"]] * len(completions))\n\n for completion, game_key, strategy, moves in zip(\n completions, game_keys, strategies, available_moves_batch\n ):\n action_str = parse_action(completion.strip(), moves)\n\n try:\n # Full episode rollout\n obs = reward_env.reset(game=game_key, strategy=strategy)\n while not obs.done:\n obs = reward_env.step(GameAction(action=action_str))\n\n coop_rate = _compute_cooperation_rate(obs)\n reward = episode_reward(\n player_score=obs.player_score,\n opponent_score=obs.opponent_score,\n cooperation_rate=coop_rate,\n total_rounds=obs.current_round,\n )\n rewards.append(reward)\n except Exception as e:\n rewards.append(-1.0)\n\n return rewards\n\n\n# Sanity check — cooperate vs defect in PD\nfor move in [\"cooperate\", \"defect\"]:\n r = kantbench_reward(\n [move], [\"...\"],\n game_key=[\"prisoners_dilemma\"],\n strategy=[\"tit_for_tat\"],\n available_moves=[[\"cooperate\", \"defect\"]],\n )\n print(f\"PD vs tit_for_tat | {move:10s} -> composite reward = {r[0]:.3f}\")"
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "markdown",
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"source": [
|
    "## Train with GRPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import torch\nfrom transformers import AutoTokenizer\nfrom trl import GRPOConfig, GRPOTrainer\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL)\nif tokenizer.pad_token is None:\n    tokenizer.pad_token = tokenizer.eos_token\n\ndef format_prompt(example):\n    messages = [\n        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n        {\"role\": \"user\", \"content\": example[\"prompt\"]},\n    ]\n    return {\"prompt\": tokenizer.apply_chat_template(\n        messages, tokenize=False, add_generation_prompt=True\n    )}\n\ntrain_dataset = dataset.map(format_prompt)\n\nconfig = GRPOConfig(\n    output_dir=\"/content/kantbench-grpo\",\n    num_generations=NUM_GENERATIONS,\n    max_completion_length=16,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=GRAD_ACCUM,\n    learning_rate=LR,\n    max_steps=MAX_STEPS,\n    logging_steps=5,\n    save_steps=50,\n    bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,\n    fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,\n    report_to=\"wandb\",\n)\n\ntrainer = GRPOTrainer(\n    model=MODEL,\n    reward_funcs=kantbench_reward,\n    args=config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n)\n\nprint(f\"Training {MODEL} on {len(GAMES)} games with GRPO\")\nprint(\"Reward: full-episode composite (payoff + cooperation + Pareto + fairness)\")\ntrainer.train()"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model(\"/content/kantbench-grpo\")\n",
    "print(\"Model saved to /content/kantbench-grpo\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate: Before vs After"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from transformers import pipeline\n\ntest_games = [\"prisoners_dilemma\", \"stag_hunt\", \"hawk_dove\", \"cournot\", \"battle_of_the_sexes\"]\nprompt_builder = PromptBuilder()\neval_env = KantEnvironment()\n\npipe = pipeline(\"text-generation\", model=\"/content/kantbench-grpo\", tokenizer=tokenizer,\n                max_new_tokens=8, do_sample=False)\n\nprint(\"=\" * 70)\nprint(f\"{'Game':<30s} {'Move':<15s} {'Episode Reward':>15s}\")\nprint(\"=\" * 70)\nfor game_key in test_games:\n    obs = eval_env.reset(game=game_key, strategy=\"tit_for_tat\")\n    prompt_text = prompt_builder.build(obs)\n    messages = [\n        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n        {\"role\": \"user\", \"content\": prompt_text},\n    ]\n    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    output = pipe(formatted)[0][\"generated_text\"][len(formatted):].strip()\n    move = parse_action(output, obs.available_actions)\n\n    # Play full episode with this move\n    obs = eval_env.reset(game=game_key, strategy=\"tit_for_tat\")\n    while not obs.done:\n        obs = eval_env.step(GameAction(action=move))\n    coop = _compute_cooperation_rate(obs)\n    r = episode_reward(obs.player_score, obs.opponent_score, coop, obs.current_round)\n\n    game_name = GAMES[game_key].name\n    print(f\"{game_name:<30s} {move:<15s} {r:>15.3f}\")"
  }
 ]
}
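The notebook cell above passes `kantbench_reward` (defined in an earlier cell) straight to `GRPOTrainer`. For readers unfamiliar with TRL's contract: a GRPO reward function receives the batch of completions plus any extra dataset columns as keyword arguments, and must return one float per completion. A minimal sketch of that contract, with a hypothetical `toy_reward` that is not part of this commit:

from typing import Any, List

def toy_reward(completions: List[str], **kwargs: Any) -> List[float]:
    # Reward terse, single-token answers; penalise rambling completions.
    return [1.0 if len(c.split()) == 1 else 0.0 for c in completions]

# GRPOTrainer(model=MODEL, reward_funcs=toy_reward, ...) would accept this.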
train/nplayer/__init__.py
ADDED
@@ -0,0 +1,34 @@
"""N-player and coalition LLM agents for game-theory environments."""

__all__ = [
    "NPlayerLLMAgent",
    "NPlayerPromptBuilder",
    "CoalitionLLMAgent",
    "CoalitionPromptBuilder",
]


def __getattr__(name: str) -> object:
    """Lazy imports to avoid pulling in heavy dependencies at load time."""
    if name in ("NPlayerLLMAgent", "NPlayerPromptBuilder"):
        from train.nplayer.nplayer_agent import (
            NPlayerLLMAgent,
            NPlayerPromptBuilder,
        )
        _map = {
            "NPlayerLLMAgent": NPlayerLLMAgent,
            "NPlayerPromptBuilder": NPlayerPromptBuilder,
        }
        return _map[name]
    if name in ("CoalitionLLMAgent", "CoalitionPromptBuilder"):
        from train.nplayer.coalition_agent import (
            CoalitionLLMAgent,
            CoalitionPromptBuilder,
        )
        _map = {
            "CoalitionLLMAgent": CoalitionLLMAgent,
            "CoalitionPromptBuilder": CoalitionPromptBuilder,
        }
        return _map[name]
    msg = f"module 'train.nplayer' has no attribute {name!r}"
    raise AttributeError(msg)
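The module-level `__getattr__` above is the PEP 562 lazy-import pattern: `import train.nplayer` stays cheap, and the agent modules load only on first attribute access. A small usage sketch, assuming the package is on the import path:

from train import nplayer

# Nothing from nplayer_agent has been imported yet; this attribute
# access triggers __getattr__, which performs the real import on demand.
agent_cls = nplayer.NPlayerLLMAgent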
train/nplayer/coalition_agent.py
ADDED
@@ -0,0 +1,249 @@
"""LLM agent for coalition formation and meta-governance environments."""
from __future__ import annotations

import json
from typing import Any, Callable, Dict, List, Optional

from env.nplayer.coalition.models import (
    CoalitionAction, CoalitionObservation,
    CoalitionProposal, CoalitionResponse,
)
from env.nplayer.governance.models import GovernanceProposal, GovernanceVote
from env.nplayer.models import NPlayerAction
from train.agent import parse_action
from constant_definitions.train.agent_constants import (
    COALITION_PROMPT_SECTION_COALITIONS,
    COALITION_PROMPT_SECTION_PHASE,
    COALITION_PROMPT_SECTION_PROPOSALS,
    COALITION_SYSTEM_PROMPT,
    GOVERNANCE_PROMPT_SECTION_PENDING,
    GOVERNANCE_PROMPT_SECTION_RULES,
    MAX_PROMPT_HISTORY_ROUNDS,
    NPLAYER_PROMPT_SECTION_ALL_SCORES,
    PROMPT_SECTION_ACTIONS, PROMPT_SECTION_GAME,
    PROMPT_SECTION_HISTORY, PROMPT_SECTION_INSTRUCTION,
)

_ZERO = int()
_ONE = int(bool(True))
_NL = "\n"
_SEP = "\n\n"
_BO = "["
_BC = "]"
_CS = ": "
_DS = "- "
_PP = "Player "
_RP = "Round "
_PS = " | "
_PL = " played: "
_PY = " payoff: "


class CoalitionPromptBuilder:
    """Formats CoalitionObservation into structured text prompts."""

    @staticmethod
    def build_negotiate(obs: CoalitionObservation) -> str:
        """Build a negotiate-phase prompt."""
        sections: List[str] = []
        base = obs.base
        sections.append(
            _BO + PROMPT_SECTION_GAME + _BC + _NL
            + base.game_name + _NL + base.game_description
        )
        sections.append(
            _BO + COALITION_PROMPT_SECTION_PHASE + _BC + _NL
            + obs.phase + _NL + "Enforcement" + _CS + obs.enforcement
        )
        if obs.pending_proposals:
            prop_lines = [
                str(idx) + _CS + "proposer=" + str(p.proposer)
                + " members=" + str(p.members)
                + " action=" + p.agreed_action
                for idx, p in enumerate(obs.pending_proposals)
            ]
            sections.append(
                _BO + COALITION_PROMPT_SECTION_PROPOSALS + _BC
                + _NL + _NL.join(prop_lines)
            )
        if obs.active_coalitions:
            coal_lines = [
                "members=" + str(c.members) + " action=" + c.agreed_action
                for c in obs.active_coalitions
            ]
            sections.append(
                _BO + COALITION_PROMPT_SECTION_COALITIONS + _BC
                + _NL + _NL.join(coal_lines)
            )
        if obs.current_rules is not None:
            rules = obs.current_rules
            active_mechs = [k for k, v in rules.mechanics.items() if v]
            sections.append(
                _BO + GOVERNANCE_PROMPT_SECTION_RULES + _BC + _NL
                + "enforcement" + _CS + rules.enforcement + _NL
                + "active_mechanics" + _CS + str(active_mechs)
            )
        if obs.pending_governance:
            gov_lines = [
                str(i) + _CS + gp.proposal_type + " by " + _PP + str(gp.proposer)
                for i, gp in enumerate(obs.pending_governance)
            ]
            sections.append(
                _BO + GOVERNANCE_PROMPT_SECTION_PENDING + _BC
                + _NL + _NL.join(gov_lines)
            )
        score_lines = [
            _PP + str(i) + _CS + str(s)
            for i, s in enumerate(obs.adjusted_scores)
        ]
        sections.append(
            _BO + NPLAYER_PROMPT_SECTION_ALL_SCORES + _BC
            + _NL + _NL.join(score_lines)
        )
        action_lines = [_DS + a for a in base.available_actions]
        sections.append(
            _BO + PROMPT_SECTION_ACTIONS + _BC + _NL + _NL.join(action_lines)
        )
        sections.append(
            _BO + PROMPT_SECTION_INSTRUCTION + _BC + _NL + COALITION_SYSTEM_PROMPT
        )
        return _SEP.join(sections)

    @staticmethod
    def build_action(obs: CoalitionObservation) -> str:
        """Build an action-phase prompt."""
        sections: List[str] = []
        base = obs.base
        sections.append(
            _BO + PROMPT_SECTION_GAME + _BC + _NL
            + base.game_name + _NL + base.game_description
        )
        sections.append(
            _BO + COALITION_PROMPT_SECTION_PHASE + _BC + _NL + obs.phase
        )
        my_coals = [
            "members=" + str(c.members) + " agreed_action=" + c.agreed_action
            for c in obs.active_coalitions
            if base.player_index in c.members
        ]
        if my_coals:
            sections.append(
                _BO + COALITION_PROMPT_SECTION_COALITIONS + _BC
                + _NL + _NL.join(my_coals)
            )
        if base.history:
            h_lines: List[str] = []
            for rnd in base.history[-MAX_PROMPT_HISTORY_ROUNDS:]:
                parts = [_RP + str(rnd.round_number)]
                for pidx, (act, pay) in enumerate(zip(rnd.actions, rnd.payoffs)):
                    parts.append(
                        _PP + str(pidx) + _PL + act + _PY + str(pay)
                    )
                h_lines.append(_PS.join(parts))
            sections.append(
                _BO + PROMPT_SECTION_HISTORY + _BC + _NL + _NL.join(h_lines)
            )
        action_lines = [_DS + a for a in base.available_actions]
        sections.append(
            _BO + PROMPT_SECTION_ACTIONS + _BC + _NL + _NL.join(action_lines)
        )
        sections.append(
            _BO + PROMPT_SECTION_INSTRUCTION + _BC + _NL
            + "Choose your action. Respond with ONLY the action name."
        )
        return _SEP.join(sections)


def _safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
    """Try to parse JSON from LLM output, return None on failure."""
    stripped = text.strip()
    start = stripped.find("{")
    end = stripped.rfind("}")
    if start >= _ZERO and end > start:
        try:
            return json.loads(stripped[start:end + _ONE])
        except (json.JSONDecodeError, ValueError):
            pass
    return None


class CoalitionLLMAgent:
    """LLM-based agent for coalition environments.

    Implements the negotiate + act protocol expected by
    CoalitionTournamentRunner.
    """

    def __init__(
        self, generate_fn: Callable[[str], str],
        player_index: int = _ZERO,
        prompt_builder: Optional[CoalitionPromptBuilder] = None,
    ) -> None:
        self._generate_fn = generate_fn
        self._player_index = player_index
        self._prompt_builder = prompt_builder or CoalitionPromptBuilder()

    def negotiate(self, obs: CoalitionObservation) -> CoalitionAction:
        """Generate coalition proposals and responses to pending ones."""
        prompt = self._prompt_builder.build_negotiate(obs)
        completion = self._generate_fn(prompt)
        parsed = _safe_json_parse(completion)
        if parsed is not None:
            proposals = self._extract_proposals(parsed, obs)
            responses = self._extract_responses(parsed, obs)
        else:
            proposals = []
            responses = self._default_responses(obs)
        return CoalitionAction(proposals=proposals, responses=responses)

    def act(self, obs: CoalitionObservation) -> NPlayerAction:
        """Select a game action during the action phase."""
        prompt = self._prompt_builder.build_action(obs)
        completion = self._generate_fn(prompt)
        action_str = parse_action(completion, obs.base.available_actions)
        return NPlayerAction(action=action_str)

    def _extract_proposals(
        self, data: Dict[str, Any], obs: CoalitionObservation,
    ) -> List[CoalitionProposal]:
        raw = data.get("proposals", [])
        if not isinstance(raw, list):
            return []
        result: List[CoalitionProposal] = []
        for item in raw:
            if not isinstance(item, dict):
                continue
            members = item.get("members", [])
            action = item.get("agreed_action", "")
            if isinstance(members, list) and action in obs.base.available_actions:
                result.append(CoalitionProposal(
                    proposer=self._player_index,
                    members=members, agreed_action=action,
                ))
        return result

    def _extract_responses(
        self, data: Dict[str, Any], obs: CoalitionObservation,
    ) -> List[CoalitionResponse]:
        raw = data.get("responses", {})
        if not isinstance(raw, dict):
            return self._default_responses(obs)
        result: List[CoalitionResponse] = []
        for idx in range(len(obs.pending_proposals)):
            accepted = raw.get(str(idx), True)
            result.append(CoalitionResponse(
                responder=self._player_index,
                proposal_index=idx, accepted=bool(accepted),
            ))
        return result

    def _default_responses(
        self, obs: CoalitionObservation,
    ) -> List[CoalitionResponse]:
        return [
            CoalitionResponse(
                responder=self._player_index,
                proposal_index=idx, accepted=True,
            )
            for idx in range(len(obs.pending_proposals))
        ]
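During the negotiate phase the agent expects the model to emit a JSON object with "proposals" and "responses" keys; `_safe_json_parse` slices from the first `{` to the last `}`, so surrounding prose is tolerated. A hypothetical completion that the extractor methods above would accept (run inside the module, or with `_safe_json_parse` imported):

completion = (
    "I will team up with players 1 and 2.\n"
    '{"proposals": [{"members": [0, 1, 2], "agreed_action": "cooperate"}],\n'
    ' "responses": {"0": true, "1": false}}'
)
parsed = _safe_json_parse(completion)
assert parsed is not None and len(parsed["proposals"]) == 1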
train/nplayer/nplayer_agent.py
ADDED
@@ -0,0 +1,146 @@
"""LLM agent for N-player game-theory environments."""

from __future__ import annotations

from typing import Callable, List, Optional

from env.nplayer.models import NPlayerAction, NPlayerObservation
from train.agent import parse_action
from constant_definitions.train.agent_constants import (
    MAX_PROMPT_HISTORY_ROUNDS,
    NPLAYER_PROMPT_SECTION_ALL_SCORES,
    NPLAYER_PROMPT_SECTION_PLAYERS,
    NPLAYER_SYSTEM_PROMPT,
    PROMPT_SECTION_ACTIONS,
    PROMPT_SECTION_GAME,
    PROMPT_SECTION_HISTORY,
    PROMPT_SECTION_INSTRUCTION,
    PROMPT_SECTION_SCORES,
)

_ZERO = int()
_ONE = int(bool(True))
_NEWLINE = "\n"
_SECTION_SEP = "\n\n"
_BRACKET_OPEN = "["
_BRACKET_CLOSE = "]"
_COLON_SPACE = ": "
_DASH_SPACE = "- "
_ROUND_PREFIX = "Round "
_PIPE_SEP = " | "
_PLAYER_PREFIX = "Player "
_PLAYED = " played: "
_PAYOFF = " payoff: "
_YOUR_LABEL = "Your score"
_ROUND_LABEL = "Round"
_OF = " of "
_YOU_ARE = "You are Player "
_OUT_OF = " out of "
_PLAYERS = " players"


class NPlayerPromptBuilder:
    """Formats NPlayerObservation into a structured text prompt."""

    @staticmethod
    def build(obs: NPlayerObservation) -> str:
        """Build a structured prompt from an N-player observation."""
        sections: List[str] = []

        # Game section
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_GAME + _BRACKET_CLOSE
            + _NEWLINE + obs.game_name
            + _NEWLINE + obs.game_description
        )

        # Players section
        sections.append(
            _BRACKET_OPEN + NPLAYER_PROMPT_SECTION_PLAYERS + _BRACKET_CLOSE
            + _NEWLINE + _YOU_ARE + str(obs.player_index)
            + _OUT_OF + str(obs.num_players) + _PLAYERS
        )

        # History section
        if obs.history:
            history_lines: List[str] = []
            history_slice = obs.history[-MAX_PROMPT_HISTORY_ROUNDS:]
            for rnd in history_slice:
                parts: List[str] = [_ROUND_PREFIX + str(rnd.round_number)]
                for pidx, (act, pay) in enumerate(
                    zip(rnd.actions, rnd.payoffs),
                ):
                    parts.append(
                        _PLAYER_PREFIX + str(pidx)
                        + _PLAYED + act
                        + _PAYOFF + str(pay)
                    )
                history_lines.append(_PIPE_SEP.join(parts))
            sections.append(
                _BRACKET_OPEN + PROMPT_SECTION_HISTORY + _BRACKET_CLOSE
                + _NEWLINE + _NEWLINE.join(history_lines)
            )

        # Scores section
        score_lines: List[str] = []
        for sidx, score in enumerate(obs.scores):
            label = _PLAYER_PREFIX + str(sidx) + _COLON_SPACE + str(score)
            score_lines.append(label)
        sections.append(
            _BRACKET_OPEN + NPLAYER_PROMPT_SECTION_ALL_SCORES + _BRACKET_CLOSE
            + _NEWLINE + _NEWLINE.join(score_lines)
            + _NEWLINE + _ROUND_LABEL + _COLON_SPACE + str(obs.current_round)
            + _OF + str(obs.total_rounds)
        )

        # Available actions
        action_lines = [_DASH_SPACE + a for a in obs.available_actions]
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_ACTIONS + _BRACKET_CLOSE
            + _NEWLINE + _NEWLINE.join(action_lines)
        )

        # Instruction
        sections.append(
            _BRACKET_OPEN + PROMPT_SECTION_INSTRUCTION + _BRACKET_CLOSE
            + _NEWLINE + NPLAYER_SYSTEM_PROMPT
        )

        return _SECTION_SEP.join(sections)


class NPlayerLLMAgent:
    """LLM-based agent for N-player environments.

    Compatible with NPlayerEnvironment.opponent_fns interface:
    Callable[[NPlayerObservation], NPlayerAction].
    """

    def __init__(
        self,
        generate_fn: Callable[[str], str],
        prompt_builder: Optional[NPlayerPromptBuilder] = None,
    ) -> None:
        self._generate_fn = generate_fn
        self._prompt_builder = prompt_builder or NPlayerPromptBuilder()
        self._last_prompt: str = ""
        self._last_completion: str = ""

    def __call__(self, obs: NPlayerObservation) -> NPlayerAction:
        """Select an action given an N-player observation."""
        prompt = self._prompt_builder.build(obs)
        self._last_prompt = prompt
        completion = self._generate_fn(prompt)
        self._last_completion = completion
        action_str = parse_action(completion, obs.available_actions)
        return NPlayerAction(action=action_str)

    @property
    def last_prompt(self) -> str:
        """The most recently constructed prompt."""
        return self._last_prompt

    @property
    def last_completion(self) -> str:
        """The most recent raw model completion."""
        return self._last_completion
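Because the agent only needs a `(str) -> str` callable, it can be smoke-tested without loading any model. An illustrative stub, not part of the commit:

# A scripted "always defect" agent: the lambda ignores the prompt entirely.
stub_agent = NPlayerLLMAgent(generate_fn=lambda prompt: "defect")
# action = stub_agent(obs)          # obs is an NPlayerObservation
# print(stub_agent.last_prompt)     # inspect the prompt that was built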
train/requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch>=2.4.0
transformers>=4.47.0
trl>=0.12.0
datasets>=3.0.0
accelerate>=1.0.0
peft>=0.13.0
openenv-core>=0.2.0
huggingface_hub>=0.26.0
bitsandbytes>=0.44.0
train/rewards.py
ADDED
@@ -0,0 +1,206 @@
"""Reward functions for the training pipeline."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from constant_definitions.game_constants import (
    EVAL_HALF,
    EVAL_ONE,
    EVAL_ONE_FLOAT,
    EVAL_TWO,
    EVAL_ZERO,
    EVAL_ZERO_FLOAT,
)
from constant_definitions.train.grpo_constants import (
    GRPO_SHAPING_ALPHA_DENOMINATOR,
    GRPO_SHAPING_ALPHA_NUMERATOR,
)

_FIVE = EVAL_TWO + EVAL_TWO + EVAL_ONE

# Default weight per sub-metric (equal weighting across five metrics).
_DEFAULT_WEIGHT_NUMERATOR = EVAL_ONE
_DEFAULT_WEIGHT_DENOMINATOR = _FIVE


def _default_weights() -> Dict[str, float]:
    """Return default equal weights for the five reward components."""
    w = _DEFAULT_WEIGHT_NUMERATOR / _DEFAULT_WEIGHT_DENOMINATOR
    return {
        "cooperation_rate": w,
        "pareto_efficiency": w,
        "fairness_index": w,
        "exploitation_resistance": w,
        "adaptability": w,
    }


# ---------------------------------------------------------------------------
# Per-episode reward
# ---------------------------------------------------------------------------


def episode_reward(
    player_score: float,
    opponent_score: float,
    cooperation_rate: float,
    total_rounds: int,
    weights: Optional[Dict[str, float]] = None,
) -> float:
    """Compute a scalar reward for a single episode.

    Uses per-episode metrics that can be computed without cross-strategy data:
    cooperation_rate, pareto_efficiency proxy, and fairness_index.

    Exploitation_resistance and adaptability default to neutral since they
    require cross-strategy comparison (see ``batch_reward``).
    """
    w = weights if weights is not None else _default_weights()

    # Cooperation rate: direct
    coop = cooperation_rate

    # Pareto efficiency proxy: normalised joint score
    joint = player_score + opponent_score
    if total_rounds > EVAL_ZERO:
        pareto_proxy = joint / total_rounds
        # Clamp to [zero, one]
        pareto_proxy = max(EVAL_ZERO_FLOAT, min(EVAL_ONE_FLOAT, pareto_proxy))
    else:
        pareto_proxy = EVAL_ZERO_FLOAT

    # Fairness: EVAL_ONE_FLOAT - |p - o| / (|p| + |o|)
    denom = abs(player_score) + abs(opponent_score)
    if denom > EVAL_ZERO_FLOAT:
        fairness = EVAL_ONE_FLOAT - abs(player_score - opponent_score) / denom
    else:
        fairness = EVAL_ONE_FLOAT

    # Cross-strategy metrics default to neutral midpoint
    exploit_resist = EVAL_HALF
    adapt = EVAL_HALF

    reward = (
        w["cooperation_rate"] * coop
        + w["pareto_efficiency"] * pareto_proxy
        + w["fairness_index"] * fairness
        + w["exploitation_resistance"] * exploit_resist
        + w["adaptability"] * adapt
    )
    return reward


# ---------------------------------------------------------------------------
# Batch reward (cross-strategy)
# ---------------------------------------------------------------------------


def batch_reward(
    episode_results: List[Dict[str, Any]],
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Compute cross-strategy reward metrics over a batch of episodes.

    Parameters
    ----------
    episode_results : list of dict
        Each dict must have keys: ``game``, ``strategy``,
        ``player_score``, ``opponent_score``, ``cooperation_rate``.

    Returns
    -------
    dict
        Mapping of metric name to value for exploitation_resistance
        and adaptability computed across strategies for each game.
    """
    w = weights if weights is not None else _default_weights()

    # Group by game
    by_game: Dict[str, List[Dict[str, Any]]] = {}
    for ep in episode_results:
        game = ep["game"]
        if game not in by_game:
            by_game[game] = []
        by_game[game].append(ep)

    exploit_scores: List[float] = []
    adapt_scores: List[float] = []

    for _game, episodes in by_game.items():
        # Group by strategy within game
        by_strat: Dict[str, List[Dict[str, Any]]] = {}
        for ep in episodes:
            strat = ep["strategy"]
            if strat not in by_strat:
                by_strat[strat] = []
            by_strat[strat].append(ep)

        if len(by_strat) <= EVAL_ONE:
            continue

        # Exploitation resistance: performance against always_defect
        # relative to best/worst across strategies
        strat_scores = {
            s: sum(e["player_score"] for e in eps)
            for s, eps in by_strat.items()
        }
        best = max(strat_scores.values())
        worst = min(strat_scores.values())
        spread = best - worst
        if "always_defect" in strat_scores and spread > EVAL_ZERO_FLOAT:
            ad_score = strat_scores["always_defect"]
            exploit_scores.append((ad_score - worst) / spread)

        # Adaptability: variance of cooperation rates across strategies
        coop_rates = []
        for eps in by_strat.values():
            rate_sum = sum(e["cooperation_rate"] for e in eps)
            coop_rates.append(rate_sum / len(eps))

        if len(coop_rates) > EVAL_ONE:
            mean_coop = sum(coop_rates) / len(coop_rates)
            var = sum(
                (r - mean_coop) ** EVAL_TWO for r in coop_rates
            ) / len(coop_rates)
            capped = min(var, EVAL_HALF)
            adapt_scores.append(capped / EVAL_HALF)

    exploit_val = (
        sum(exploit_scores) / len(exploit_scores)
        if exploit_scores else EVAL_HALF
    )
    adapt_val = (
        sum(adapt_scores) / len(adapt_scores)
        if adapt_scores else EVAL_ZERO_FLOAT
    )

    return {
        "exploitation_resistance": exploit_val,
        "adaptability": adapt_val,
    }


# ---------------------------------------------------------------------------
# Per-step shaping
# ---------------------------------------------------------------------------


def per_step_shaping(
    player_payoff: float,
    opponent_payoff: float,
    payoff_min: float,
    payoff_max: float,
) -> float:
    """Optional per-step reward shaping based on immediate payoffs.

    Returns a small bonus proportional to normalised joint payoff,
    scaled by the shaping coefficient alpha.
    """
    alpha = GRPO_SHAPING_ALPHA_NUMERATOR / GRPO_SHAPING_ALPHA_DENOMINATOR
    payoff_range = payoff_max - payoff_min
    if payoff_range <= EVAL_ZERO_FLOAT:
        return EVAL_ZERO_FLOAT
    joint = player_payoff + opponent_payoff
    normalised = (joint - payoff_min * EVAL_TWO) / (payoff_range * EVAL_TWO)
    return alpha * normalised
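A worked example of `episode_reward`, assuming the EVAL_* constants resolve to the obvious values (0, 0.5, 1, 2) and the default equal 1/5 weights; the inputs are made up:

# player_score=18, opponent_score=12, cooperation_rate=0.8, total_rounds=30
# pareto_proxy   = (18 + 12) / 30 = 1.0           (already within [0, 1])
# fairness       = 1 - |18 - 12| / (18 + 12) = 0.8
# exploit_resist = adapt = 0.5                    (neutral midpoints)
# reward = (0.8 + 1.0 + 0.8 + 0.5 + 0.5) / 5 = 0.72
r = episode_reward(18.0, 12.0, 0.8, 30)
assert abs(r - 0.72) < 1e-9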
train/self_play/__init__.py
ADDED
@@ -0,0 +1 @@
"""Self-play multi-agent training infrastructure."""
train/self_play/config.py
ADDED
@@ -0,0 +1,55 @@
"""Configuration for self-play GRPO training."""

from __future__ import annotations

from dataclasses import dataclass

from constant_definitions.train.grpo_constants import (
    GRPO_BATCH_SIZE,
    GRPO_LR_DENOMINATOR,
    GRPO_LR_NUMERATOR,
    GRPO_MAX_COMPLETION_LENGTH,
    GRPO_NUM_GENERATIONS,
)
from constant_definitions.var.meta.self_play_constants import (
    SELF_PLAY_DEFAULT_EPISODES_PER_STEP,
    SELF_PLAY_DEFAULT_MAX_STEPS,
    SELF_PLAY_OPPONENT_UPDATE_INTERVAL,
    SELF_PLAY_POOL_MAX_SIZE,
    SELF_PLAY_WARMUP_EPISODES,
)


@dataclass
class SelfPlayConfig:
    """Configuration for self-play GRPO training.

    Combines self-play-specific settings (opponent pool management,
    update frequency) with standard GRPO training parameters.
    """

    # Model
    model_name: str = "Qwen/Qwen2.5-3B-Instruct"
    output_dir: str = "./kantbench-self-play"

    # Self-play specific
    opponent_update_interval: int = SELF_PLAY_OPPONENT_UPDATE_INTERVAL
    pool_max_size: int = SELF_PLAY_POOL_MAX_SIZE
    episodes_per_step: int = SELF_PLAY_DEFAULT_EPISODES_PER_STEP
    warmup_episodes: int = SELF_PLAY_WARMUP_EPISODES

    # GRPO params
    learning_rate_numerator: int = GRPO_LR_NUMERATOR
    learning_rate_denominator: int = GRPO_LR_DENOMINATOR
    batch_size: int = GRPO_BATCH_SIZE
    num_generations: int = GRPO_NUM_GENERATIONS
    max_completion_length: int = GRPO_MAX_COMPLETION_LENGTH
    max_steps: int = SELF_PLAY_DEFAULT_MAX_STEPS

    # Cross-model mode: if set, opponent is loaded from this path
    cross_model_path: str = ""

    @property
    def learning_rate(self) -> float:
        """Compute learning rate from numerator/denominator."""
        return self.learning_rate_numerator / self.learning_rate_denominator
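The numerator/denominator pairs keep float literals out of the dataclass (consistent with the constants-only style used across this codebase); the `learning_rate` property performs the division. An illustrative override, with made-up values rather than the real GRPO_LR_* constants:

cfg = SelfPlayConfig(
    learning_rate_numerator=1,
    learning_rate_denominator=1_000_000,
)
assert cfg.learning_rate == 1e-6  # 1 / 1_000_000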
train/self_play/oauth.py
ADDED
@@ -0,0 +1,191 @@
"""OAuth token management for Anthropic and OpenAI self-play integration."""

from __future__ import annotations

import base64
import json
import os
from typing import Optional, Tuple

import httpx

from constant_definitions.var.meta.self_play_constants import (
    ANTHROPIC_OAUTH_TOKEN_URL,
    ANTHROPIC_OAUTH_CLIENT_ID,
    OPENAI_OAUTH_TOKEN_URL,
    OPENAI_OAUTH_CLIENT_ID,
    SUPABASE_OAUTH_TABLE,
    SUPABASE_PROVIDER_ANTHROPIC,
    SUPABASE_PROVIDER_OPENAI,
)

_ZERO = int()
_ONE = int(bool(True))
_CONTENT_TYPE_FORM = "application/x-www-form-urlencoded"


def _read_env_file() -> dict[str, str]:
    """Read content-platform .env.local into a dict."""
    env_path = os.path.join(
        os.path.expanduser("~"),
        "Documents", "CodingProjects", "Wisent",
        "content-platform", ".env.local",
    )
    env_vars: dict[str, str] = {}
    with open(env_path) as fh:
        for line in fh:
            if "=" in line and not line.startswith("#"):
                key, val = line.split("=", _ONE)
                env_vars[key] = (
                    val.strip().strip('"').replace("\\n", "").strip()
                )
    return env_vars


def _supabase_headers(service_key: str) -> dict[str, str]:
    """Return Supabase REST API headers."""
    return {
        "apikey": service_key,
        "Authorization": "Bearer " + service_key,
        "Content-Type": "application/json",
        "Prefer": "return=minimal",
    }


def fetch_refresh_token(
    provider: str,
    supabase_url: str = "",
    service_key: str = "",
) -> Tuple[str, str]:
    """Fetch the first refresh token for *provider* from Supabase.

    Returns (credential_id, refresh_token).
    """
    if not supabase_url or not service_key:
        env = _read_env_file()
        supabase_url = supabase_url or env["NEXT_PUBLIC_SUPABASE_URL"]
        service_key = service_key or env["SUPABASE_SERVICE_ROLE_KEY"]
    resp = httpx.get(
        supabase_url + "/rest/v" + str(_ONE) + "/" + SUPABASE_OAUTH_TABLE,
        params={"provider": "eq." + provider, "select": "*"},
        headers=_supabase_headers(service_key),
    )
    rows = resp.json()
    if not rows:
        raise RuntimeError(f"No {provider} credentials in Supabase")
    row = rows[_ZERO]
    return row["id"], row["refresh_token"]


def save_refresh_token(
    credential_id: str,
    new_refresh_token: str,
    access_token: str = "",
    supabase_url: str = "",
    service_key: str = "",
) -> None:
    """Save a rotated refresh token back to Supabase."""
    if not supabase_url or not service_key:
        env = _read_env_file()
        supabase_url = supabase_url or env["NEXT_PUBLIC_SUPABASE_URL"]
        service_key = service_key or env["SUPABASE_SERVICE_ROLE_KEY"]
    body: dict[str, str] = {"refresh_token": new_refresh_token}
    if access_token:
        body["access_token"] = access_token
    httpx.patch(
        supabase_url + "/rest/v" + str(_ONE) + "/" + SUPABASE_OAUTH_TABLE,
        params={"id": "eq." + credential_id},
        json=body,
        headers=_supabase_headers(service_key),
    )


def exchange_anthropic(
    refresh_token: str,
) -> Tuple[str, str]:
    """Exchange Anthropic refresh token. Returns (access, new_refresh)."""
    resp = httpx.post(
        ANTHROPIC_OAUTH_TOKEN_URL,
        data={
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": ANTHROPIC_OAUTH_CLIENT_ID,
        },
        headers={"Content-Type": _CONTENT_TYPE_FORM},
    )
    resp.raise_for_status()
    data = resp.json()
    return data["access_token"], data.get("refresh_token", "")


def exchange_openai(
    refresh_token: str,
) -> Tuple[str, str, str]:
    """Exchange OpenAI refresh token. Returns (access, new_refresh, account_id)."""
    resp = httpx.post(
        OPENAI_OAUTH_TOKEN_URL,
        data={
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": OPENAI_OAUTH_CLIENT_ID,
        },
        headers={"Content-Type": _CONTENT_TYPE_FORM},
    )
    resp.raise_for_status()
    data = resp.json()
    access = data["access_token"]
    new_rt = data.get("refresh_token", "")
    account_id = _extract_account_id(data.get("id_token", ""))
    return access, new_rt, account_id


def _extract_account_id(id_token: str) -> str:
    """Extract chatgpt_account_id from an OpenAI id_token JWT."""
    if not id_token:
        return ""
    parts = id_token.split(".")
    if len(parts) < _ONE + _ONE:
        return ""
    payload = parts[_ONE]
    # Pad base64
    padding = (_ONE + _ONE + _ONE + _ONE) - len(payload) % (
        _ONE + _ONE + _ONE + _ONE
    )
    if padding < (_ONE + _ONE + _ONE + _ONE):
        payload += "=" * padding
    decoded = json.loads(base64.urlsafe_b64decode(payload))
    claims = decoded.get("https://api.openai.com/auth", {})
    return claims.get("chatgpt_account_id", "")


def get_anthropic_access_token() -> str:
    """Full flow: try all Supabase credentials until one works."""
    env = _read_env_file()
    sb_url = env["NEXT_PUBLIC_SUPABASE_URL"]
    sb_key = env["SUPABASE_SERVICE_ROLE_KEY"]
    resp = httpx.get(
        sb_url + "/rest/v" + str(_ONE) + "/" + SUPABASE_OAUTH_TABLE,
        params={"provider": "eq." + SUPABASE_PROVIDER_ANTHROPIC, "select": "*"},
        headers=_supabase_headers(sb_key),
    )
    rows = resp.json()
    last_err: Exception = RuntimeError("No credentials found")
    for row in rows:
        cred_id, rt = row["id"], row["refresh_token"]
        try:
            access, new_rt = exchange_anthropic(rt)
            if new_rt:
                save_refresh_token(cred_id, new_rt, access, sb_url, sb_key)
            return access
        except Exception as exc:
            last_err = exc
    raise last_err


def get_openai_credentials() -> Tuple[str, str]:
    """Full flow: returns (access_token, account_id)."""
    cred_id, rt = fetch_refresh_token(SUPABASE_PROVIDER_OPENAI)
    access, new_rt, account_id = exchange_openai(rt)
    if new_rt:
        save_refresh_token(cred_id, new_rt, access)
    return access, account_id
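`_extract_account_id` only base64-decodes the JWT's middle (payload) segment and never verifies the signature. A self-contained check of the padding logic using a synthetic token whose claim set mirrors the shape the function reads; no real credentials are involved:

import base64, json

claims = {"https://api.openai.com/auth": {"chatgpt_account_id": "acct-123"}}
payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).decode().rstrip("=")
token = "header." + payload + ".sig"   # unsigned, illustration only
assert _extract_account_id(token) == "acct-123"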
train/self_play/opponents.py
ADDED
@@ -0,0 +1,142 @@
"""Frozen opponents and opponent pool for self-play training."""

from __future__ import annotations

import random
from typing import Callable, List, Optional

from env.models import GameAction, GameObservation
from train.agent import PromptBuilder, parse_action
from constant_definitions.train.agent_constants import (
    MAX_ACTION_TOKENS,
    SYSTEM_PROMPT,
)
from constant_definitions.var.meta.self_play_constants import (
    SELF_PLAY_POOL_MAX_SIZE,
)

_ZERO = int()


class FrozenOpponent:
    """Wraps a generation function for use as opponent_fn in KantEnvironment.

    Runs inference with no gradients. Compatible with the
    ``opponent_fn: Callable[[GameObservation], GameAction]`` interface
    that KantEnvironment.reset() accepts.

    Parameters
    ----------
    generate_fn : callable
        A function ``(prompt: str) -> str`` that produces a completion.
    prompt_builder : PromptBuilder, optional
        Custom prompt builder. Defaults to the standard PromptBuilder.
    """

    def __init__(
        self,
        generate_fn: Callable[[str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        self._generate_fn = generate_fn
        self._builder = prompt_builder or PromptBuilder()

    def __call__(self, obs: GameObservation) -> GameAction:
        """Select an action given a game observation."""
        prompt = self._builder.build(obs)
        completion = self._generate_fn(prompt)
        action_str = parse_action(completion, obs.available_actions)
        return GameAction(action=action_str)

    @classmethod
    def from_model(
        cls,
        model: object,
        tokenizer: object,
        max_tokens: int = MAX_ACTION_TOKENS,
    ) -> FrozenOpponent:
        """Create from a HuggingFace model (runs with torch.no_grad)."""
        import torch

        def _generate(prompt: str) -> str:
            with torch.no_grad():
                inputs = tokenizer(prompt, return_tensors="pt")
                input_len = len(inputs["input_ids"][_ZERO])
                outputs = model.generate(
                    **inputs, max_new_tokens=max_tokens,
                )
                return tokenizer.decode(
                    outputs[_ZERO][input_len:],
                    skip_special_tokens=True,
                )

        return cls(generate_fn=_generate)

    @classmethod
    def from_checkpoint(
        cls,
        path: str,
        tokenizer_name: str,
        max_tokens: int = MAX_ACTION_TOKENS,
    ) -> FrozenOpponent:
        """Load a frozen opponent from a saved checkpoint directory."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        loaded_model = AutoModelForCausalLM.from_pretrained(path)
        loaded_model.eval()
        loaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        return cls.from_model(loaded_model, loaded_tokenizer, max_tokens)

    @classmethod
    def from_api(
        cls,
        api_call_fn: Callable[[str, str], str],
    ) -> FrozenOpponent:
        """Create from an API-based agent (OpenAI, Anthropic, etc.)."""
        return cls(
            generate_fn=lambda prompt: api_call_fn(SYSTEM_PROMPT, prompt),
        )


class OpponentPool:
    """Maintains a pool of past model checkpoints as diverse opponents.

    Samples uniformly from the pool for opponent diversity.
    Evicts the oldest entry when the pool exceeds ``max_size``.

    Parameters
    ----------
    max_size : int
        Maximum number of frozen opponents to keep in the pool.
    """

    def __init__(self, max_size: int = SELF_PLAY_POOL_MAX_SIZE) -> None:
        self._pool: List[FrozenOpponent] = []
        self._max_size = max_size

    def add(self, opponent: FrozenOpponent) -> None:
        """Add a frozen opponent to the pool, evicting oldest if full."""
        self._pool.append(opponent)
        if len(self._pool) > self._max_size:
            self._pool.pop(_ZERO)

    def sample(self) -> FrozenOpponent:
        """Return a randomly chosen opponent from the pool.

        Raises
        ------
        IndexError
            If the pool is empty.
        """
        if not self._pool:
            raise IndexError("Cannot sample from an empty opponent pool.")
        return random.choice(self._pool)

    def get_opponent_fn(self) -> Callable[[GameObservation], GameAction]:
        """Return a callable that uses a sampled opponent."""
        return self.sample()

    @property
    def size(self) -> int:
        """Current number of opponents in the pool."""
        return len(self._pool)
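One way to guarantee `sample()` never raises before the first checkpoint lands is to seed the pool with scripted baselines; the lambdas below are illustrative stand-ins, not part of the diff:

pool = OpponentPool(max_size=4)
pool.add(FrozenOpponent(generate_fn=lambda prompt: "cooperate"))
pool.add(FrozenOpponent(generate_fn=lambda prompt: "defect"))

opponent_fn = pool.get_opponent_fn()   # a FrozenOpponent, itself callable
# env.reset(game="prisoners_dilemma", opponent_fn=opponent_fn)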
train/self_play/trainer.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Self-play GRPO trainer for multi-agent training."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
import logging
|
| 7 |
+
import random
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
from env.environment import KantEnvironment
|
| 11 |
+
from env.models import GameAction, GameObservation
|
| 12 |
+
from train.agent import LLMAgent, PromptBuilder, parse_action
|
| 13 |
+
from train.rewards import episode_reward
|
| 14 |
+
from train.trajectory import TrajectoryCollector, EpisodeTrajectory
|
| 15 |
+
from train.self_play.opponents import FrozenOpponent, OpponentPool
|
| 16 |
+
from train.self_play.config import SelfPlayConfig
|
| 17 |
+
from constant_definitions.train.agent_constants import SYSTEM_PROMPT
|
| 18 |
+
from constant_definitions.train.grpo_constants import GRPO_LOG_EVERY
|
| 19 |
+
from constant_definitions.game_constants import EVAL_ZERO_FLOAT
|
| 20 |
+
from constant_definitions.var.meta.self_play_constants import (
|
| 21 |
+
SELF_PLAY_COOP_WEIGHT_DENOMINATOR,
|
| 22 |
+
SELF_PLAY_COOP_WEIGHT_NUMERATOR,
|
| 23 |
+
SELF_PLAY_EXPLOIT_WEIGHT_DENOMINATOR,
|
| 24 |
+
SELF_PLAY_EXPLOIT_WEIGHT_NUMERATOR,
|
| 25 |
+
SELF_PLAY_FAIRNESS_WEIGHT_DENOMINATOR,
|
| 26 |
+
SELF_PLAY_FAIRNESS_WEIGHT_NUMERATOR,
|
| 27 |
+
SELF_PLAY_PARETO_WEIGHT_DENOMINATOR,
|
| 28 |
+
SELF_PLAY_PARETO_WEIGHT_NUMERATOR,
|
| 29 |
+
SELF_PLAY_ADAPT_WEIGHT_DENOMINATOR,
|
| 30 |
+
SELF_PLAY_ADAPT_WEIGHT_NUMERATOR,
|
| 31 |
+
SELF_PLAY_OPPONENT_LABEL,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
_ZERO = int()
|
| 37 |
+
_ONE = int(bool(True))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _self_play_weights() -> Dict[str, float]:
|
| 41 |
+
"""Return reward weights tuned for self-play training."""
|
| 42 |
+
return {
|
| 43 |
+
"exploitation_resistance": (
|
| 44 |
+
SELF_PLAY_EXPLOIT_WEIGHT_NUMERATOR
|
| 45 |
+
/ SELF_PLAY_EXPLOIT_WEIGHT_DENOMINATOR
|
| 46 |
+
),
|
| 47 |
+
"cooperation_rate": (
|
| 48 |
+
SELF_PLAY_COOP_WEIGHT_NUMERATOR
|
| 49 |
+
/ SELF_PLAY_COOP_WEIGHT_DENOMINATOR
|
| 50 |
+
),
|
| 51 |
+
"pareto_efficiency": (
|
| 52 |
+
SELF_PLAY_PARETO_WEIGHT_NUMERATOR
|
| 53 |
+
/ SELF_PLAY_PARETO_WEIGHT_DENOMINATOR
|
| 54 |
+
),
|
| 55 |
+
"fairness_index": (
|
| 56 |
+
SELF_PLAY_FAIRNESS_WEIGHT_NUMERATOR
|
| 57 |
+
/ SELF_PLAY_FAIRNESS_WEIGHT_DENOMINATOR
|
| 58 |
+
),
|
| 59 |
+
"adaptability": (
|
| 60 |
+
SELF_PLAY_ADAPT_WEIGHT_NUMERATOR
|
| 61 |
+
/ SELF_PLAY_ADAPT_WEIGHT_DENOMINATOR
|
| 62 |
+
),
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class SelfPlayTrainer:
|
| 67 |
+
"""GRPO training with self-play opponents.
|
| 68 |
+
|
| 69 |
+
Training loop:
|
| 70 |
+
1. Collect trajectories: training model vs frozen opponent
|
| 71 |
+
2. Compute GRPO rewards from episode outcomes
|
| 72 |
+
3. Update training model via TRL GRPOTrainer
|
| 73 |
+
4. Periodically refresh frozen opponent from training model
|
| 74 |
+
5. Add old opponent to pool for diversity
|
| 75 |
+
|
| 76 |
+
Parameters
|
| 77 |
+
----------
|
| 78 |
+
config : SelfPlayConfig
|
| 79 |
+
Training configuration.
|
| 80 |
+
model : object
|
| 81 |
+
HuggingFace model to train.
|
| 82 |
+
tokenizer : object
|
| 83 |
+
Tokenizer for the model.
|
| 84 |
+
env : KantEnvironment, optional
|
| 85 |
+
Game environment instance.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
config: SelfPlayConfig,
|
| 91 |
+
model: object,
|
| 92 |
+
tokenizer: object,
|
| 93 |
+
env: Optional[KantEnvironment] = None,
|
| 94 |
+
) -> None:
|
| 95 |
+
self._config = config
|
| 96 |
+
self._model = model
|
| 97 |
+
self._tokenizer = tokenizer
|
| 98 |
+
self._env = env or KantEnvironment()
|
| 99 |
+
self._pool = OpponentPool(max_size=config.pool_max_size)
|
| 100 |
+
self._frozen = FrozenOpponent.from_model(model, tokenizer)
|
| 101 |
+
self._pool.add(self._frozen)
|
| 102 |
+
self._step_count = _ZERO
|
| 103 |
+
|
| 104 |
+
def _model_generate(self, prompt: str) -> str:
|
| 105 |
+
"""Generate a completion from the training model."""
|
| 106 |
+
import torch
|
| 107 |
+
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
inputs = self._tokenizer(prompt, return_tensors="pt")
|
| 110 |
+
input_len = len(inputs["input_ids"][_ZERO])
|
| 111 |
+
outputs = self._model.generate(
|
| 112 |
+
**inputs,
|
| 113 |
+
max_new_tokens=self._config.max_completion_length,
|
| 114 |
+
)
|
| 115 |
+
return self._tokenizer.decode(
|
| 116 |
+
outputs[_ZERO][input_len:],
|
| 117 |
+
skip_special_tokens=True,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def collect_trajectories(
|
| 121 |
+
self,
|
| 122 |
+
games: List[str],
|
| 123 |
+
num_episodes: int,
|
| 124 |
+
) -> List[EpisodeTrajectory]:
|
| 125 |
+
"""Collect episodes with current frozen opponent."""
|
| 126 |
+
agent = LLMAgent(generate_fn=self._model_generate)
|
| 127 |
+
collector = TrajectoryCollector(
|
| 128 |
+
env=self._env,
|
| 129 |
+
agent=agent,
|
| 130 |
+
reward_fn=lambda ps, os, cr, tr: episode_reward(
|
| 131 |
+
ps, os, cr, tr, weights=_self_play_weights(),
|
| 132 |
+
),
|
| 133 |
+
)
|
| 134 |
+
trajectories: List[EpisodeTrajectory] = []
|
| 135 |
+
for _ep in range(num_episodes):
|
| 136 |
+
game = random.choice(games)
|
| 137 |
+
opponent = self._pool.sample()
|
| 138 |
+
traj = collector.collect_episode(
|
| 139 |
+
game=game,
|
| 140 |
+
strategy=SELF_PLAY_OPPONENT_LABEL,
|
| 141 |
+
opponent_fn=opponent,
|
| 142 |
+
)
|
| 143 |
+
trajectories.append(traj)
|
| 144 |
+
return trajectories
|
| 145 |
+
|
| 146 |
+
def make_reward_fn(self) -> Callable[..., List[float]]:
|
| 147 |
+
"""Create GRPO reward function using self-play episodes."""
|
| 148 |
+
pool = self._pool
|
| 149 |
+
env = self._env
|
| 150 |
+
weights = _self_play_weights()
|
| 151 |
+
|
| 152 |
+
def reward_fn(
|
| 153 |
+
completions: List[str],
|
| 154 |
+
prompts: List[str],
|
| 155 |
+
**kwargs: Any,
|
| 156 |
+
) -> List[float]:
|
| 157 |
+
rewards: List[float] = []
|
| 158 |
+
game_keys = kwargs.get(
|
| 159 |
+
"game_key",
|
| 160 |
+
["prisoners_dilemma"] * len(completions),
|
| 161 |
+
)
|
| 162 |
+
moves_batch = kwargs.get(
|
| 163 |
+
"available_moves",
|
| 164 |
+
[["cooperate", "defect"]] * len(completions),
|
| 165 |
+
)
|
| 166 |
+
for completion, game_key, moves in zip(
|
| 167 |
+
completions, game_keys, moves_batch,
|
| 168 |
+
):
|
| 169 |
+
action_str = parse_action(completion.strip(), moves)
|
| 170 |
+
opponent = pool.sample()
|
| 171 |
+
obs = env.reset(
|
| 172 |
+
game=game_key, opponent_fn=opponent,
|
| 173 |
+
)
|
| 174 |
+
while not obs.done:
|
| 175 |
+
obs = env.step(GameAction(action=action_str))
|
| 176 |
+
reward = episode_reward(
|
| 177 |
+
obs.player_score,
|
| 178 |
+
obs.opponent_score,
|
| 179 |
+
_compute_coop_rate(obs),
|
| 180 |
+
obs.current_round,
|
| 181 |
+
weights=weights,
|
| 182 |
+
)
|
| 183 |
+
rewards.append(reward)
|
| 184 |
+
return rewards
|
| 185 |
+
|
| 186 |
+
return reward_fn
|
| 187 |
+
|
| 188 |
+
def refresh_opponent(self) -> None:
|
| 189 |
+
"""Copy current training model to a new frozen opponent."""
|
| 190 |
+
        frozen_model = copy.deepcopy(self._model)
        frozen_model.eval()
        new_opponent = FrozenOpponent.from_model(
            frozen_model, self._tokenizer,
        )
        self._pool.add(new_opponent)
        self._frozen = new_opponent
        logger.info(
            "Refreshed opponent. Pool size: %d", self._pool.size,
        )

    def train(self, games: List[str]) -> None:
        """Main self-play training loop.

        Parameters
        ----------
        games : list of str
            Game keys to train on.
        """
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer
        import torch

        trajectories = self.collect_trajectories(
            games, self._config.warmup_episodes,
        )
        samples = []
        for traj in trajectories:
            for step in traj.steps:
                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": step.prompt},
                ]
                formatted = self._tokenizer.apply_chat_template(
                    messages, tokenize=False,
                    add_generation_prompt=True,
                )
                samples.append({
                    "prompt": formatted,
                    "game_key": traj.game,
                    "available_moves": ["cooperate", "defect"],
                })
        dataset = Dataset.from_list(samples)

        reward_fn = self.make_reward_fn()

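        # GRPO needs no value model: num_generations completions are sampled
        # per prompt, and each one's advantage is computed relative to the
        # group's mean reward.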
        trl_config = GRPOConfig(
            output_dir=self._config.output_dir,
            num_generations=self._config.num_generations,
            max_completion_length=self._config.max_completion_length,
            per_device_train_batch_size=self._config.batch_size,
            learning_rate=self._config.learning_rate,
            max_steps=self._config.max_steps,
            logging_steps=GRPO_LOG_EVERY,
            save_steps=self._config.opponent_update_interval,
            bf16=torch.cuda.is_available(),
        )

        trainer = GRPOTrainer(
            model=self._model,
            reward_funcs=reward_fn,
            args=trl_config,
            train_dataset=dataset,
            processing_class=self._tokenizer,
        )

        trainer.train()
        trainer.save_model(self._config.output_dir)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_COOPERATIVE_ACTIONS = frozenset({"cooperate", "stag", "dove"})


def _compute_coop_rate(obs: GameObservation) -> float:
    """Fraction of cooperative moves in an episode."""
    if not obs.history:
        return EVAL_ZERO_FLOAT
    total = len(obs.history)
    count = _ZERO
    for rnd in obs.history:
        if rnd.player_action in _COOPERATIVE_ACTIONS:
            count += _ONE
    return count / total

train/splits.py
ADDED
@@ -0,0 +1,77 @@
"""Deterministic stratified train/eval game split."""

from __future__ import annotations

import random
from typing import Dict, FrozenSet, List, Set, Tuple

from common.games_meta.game_tags import GAME_TAGS
from constant_definitions.batch4.tag_constants import CATEGORIES
from constant_definitions.game_constants import EVAL_ZERO, EVAL_ONE
from constant_definitions.train.split_constants import (
    MIN_EVAL_TAG_FRACTION_DENOMINATOR,
    MIN_EVAL_TAG_FRACTION_NUMERATOR,
    SPLIT_SEED,
    TRAIN_FRACTION_DENOMINATOR,
    TRAIN_FRACTION_NUMERATOR,
)

# Domain tags are used for stratification
_DOMAIN_TAGS: List[str] = CATEGORIES["domain"]


def get_train_eval_split(
    seed: int = SPLIT_SEED,
) -> Tuple[FrozenSet[str], FrozenSet[str]]:
    """Return (train_games, eval_games) as frozen sets of game keys.

    The split is deterministic for a given seed and stratified so that
    every domain tag has at least ``MIN_EVAL_TAG_FRACTION`` representation
    in the eval set.
    """
    all_games = sorted(GAME_TAGS.keys())
    rng = random.Random(seed)

    # Build domain -> games index
    domain_to_games: Dict[str, List[str]] = {tag: [] for tag in _DOMAIN_TAGS}
    for game_key in all_games:
        tags = GAME_TAGS[game_key]
        for dtag in _DOMAIN_TAGS:
            if dtag in tags:
                domain_to_games[dtag].append(game_key)

    # Guarantee minimum eval representation per domain
    eval_set: Set[str] = set()
    for dtag in _DOMAIN_TAGS:
        games_with_tag = domain_to_games[dtag]
        if not games_with_tag:
            continue
        min_eval = _min_eval_count(len(games_with_tag))
        already_in_eval = [g for g in games_with_tag if g in eval_set]
        needed = min_eval - len(already_in_eval)
        if needed > EVAL_ZERO:
            candidates = [g for g in games_with_tag if g not in eval_set]
            rng.shuffle(candidates)
            for g in candidates[:needed]:
                eval_set.add(g)

    # Fill remaining eval slots up to target size
    total = len(all_games)
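    # Train size is the integer floor of TRAIN_FRACTION; eval takes the
    # complement, so rounding never drops a game from both splits.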
    target_train = (total * TRAIN_FRACTION_NUMERATOR) // TRAIN_FRACTION_DENOMINATOR
    target_eval = total - target_train
    remaining = [g for g in all_games if g not in eval_set]
    rng.shuffle(remaining)
    slots_to_fill = target_eval - len(eval_set)
    if slots_to_fill > EVAL_ZERO:
        for g in remaining[:slots_to_fill]:
            eval_set.add(g)

    train_set = frozenset(g for g in all_games if g not in eval_set)
    return train_set, frozenset(eval_set)


def _min_eval_count(tag_total: int) -> int:
    """Minimum number of games with a given tag that must be in eval."""
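    # Ceiling division: ceil(tag_total * NUM / DEN) using integers only.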
    _numer = tag_total * MIN_EVAL_TAG_FRACTION_NUMERATOR
    result = (_numer + MIN_EVAL_TAG_FRACTION_DENOMINATOR - EVAL_ONE) // MIN_EVAL_TAG_FRACTION_DENOMINATOR
    return max(result, EVAL_ONE)

train/train.py
ADDED
@@ -0,0 +1,403 @@
"""KantBench GRPO Training Script.

Trains a language model to play 2-player game theory games optimally
using Group Relative Policy Optimization (GRPO) via TRL.

The KantBench environment runs as a remote OpenEnv server (HF Space):
- Each GRPO completion is a single move
- The reward function plays a FULL multi-round episode using that move
  as the agent's consistent strategy via the OpenEnv client
- The composite reward (payoff + cooperation + Pareto efficiency + fairness)
  becomes the GRPO signal

Supports the full KantBench game library including:
- 90+ base 2-player games and 3 N-player games
- 9 pre-registered meta-games (rule_proposal, rule_signal, gossip)
- Dynamic variant composition (cheap_talk, exit, binding_commitment,
  constitutional, proposer_responder, noisy_actions, noisy_payoffs)

Usage:
    python -m train.train --model Qwen/Qwen2.5-7B-Instruct --max-steps 200
"""

from __future__ import annotations

import argparse
import logging
import random
from typing import Any, List

import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer

from common.games import GAMES
from common.strategies import STRATEGIES as STRATEGY_REGISTRY
from spaces.kant.client import KantBenchEnv
from spaces.kant.models import KantBenchAction, KantBenchObservation
from train.agent import parse_action
from train.rewards import episode_reward
from train.splits import get_train_eval_split

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

KANTBENCH_URL = "https://openenv-community-kantbench.hf.space"

SYSTEM_PROMPT = (
    "You are playing a game-theory game. Analyse the situation and choose "
    "the best action. Respond with ONLY the action name, nothing else."
)

# Variants that can be dynamically composed on top of base games.
# These are applied server-side via the variant= reset parameter.
TRAINABLE_VARIANTS = [
    "cheap_talk",
    "exit",
    "binding_commitment",
    "constitutional",
    "noisy_actions",
    "noisy_payoffs",
    "rule_proposal",
    "rule_signal",
    "gossip",
]

# Base games suitable for variant composition (2-player matrix games).
VARIANT_BASE_GAMES = [
    "prisoners_dilemma",
    "stag_hunt",
    "hawk_dove",
]

# Fraction of dataset samples that use dynamic variant composition.
VARIANT_FRACTION = 0.3


# ---------------------------------------------------------------------------
# Helpers to bridge KantBenchObservation -> training code
# ---------------------------------------------------------------------------


def _obs_cooperation_rate(obs: KantBenchObservation) -> float:
    """Compute cooperation rate from a KantBenchObservation's history."""
    if not obs.history:
        return 0.0
    coop_actions = {"cooperate", "stag", "dove", "contribute"}
    coop_count = sum(
        1 for h in obs.history
        if any(ca in h.get("your_move", "") for ca in coop_actions)
    )
    return coop_count / len(obs.history)


def _build_prompt(obs: KantBenchObservation) -> str:
    """Build a structured prompt from a KantBenchObservation.

    Mirrors PromptBuilder.build() but works with the OpenEnv client's
    observation format.
    """
    sections: list[str] = []

    # Game section
    sections.append(
        f"[Game]\n{obs.game_name}\n{obs.game_description}"
    )

    # History section
    if obs.history:
        history_lines: list[str] = []
        for h in obs.history[-5:]:  # Last 5 rounds
            line = (
                f"Round {h.get('round', '?')}"
                f" | You played: {h.get('your_move', '?')}"
                f" | Opponent played: {h.get('opponent_move', '?')}"
                f" | Your payoff: {h.get('your_payoff', '?')}"
                f" | Opp payoff: {h.get('opponent_payoff', '?')}"
            )
            history_lines.append(line)
        sections.append("[History]\n" + "\n".join(history_lines))

    # Scores section
    sections.append(
        f"[Scores]\nYour score: {obs.cumulative_score}"
        f"\nRound: {obs.round_number} of {obs.max_rounds}"
    )

    # Available actions
    action_lines = [f"- {a}" for a in obs.available_moves]
    sections.append("[Available Actions]\n" + "\n".join(action_lines))

    # Instruction
    sections.append(f"[Instruction]\n{SYSTEM_PROMPT}")

    return "\n\n".join(sections)


# ---------------------------------------------------------------------------
# Dataset generation using PromptBuilder
# ---------------------------------------------------------------------------


def build_dataset(
    base_url: str,
    n_samples: int = 1000,
    games: list[str] | None = None,
    strategies: list[str] | None = None,
    variant_fraction: float = VARIANT_FRACTION,
) -> Dataset:
    """Generate diverse game theory prompts for GRPO training.

    Connects to the KantBench OpenEnv server to generate real observations,
    then builds structured prompts from diverse game states.

    A fraction of samples use dynamic variant composition (cheap_talk,
    constitutional, gossip, etc.) to train on meta-gaming scenarios.
    """
    game_keys = games or list(GAMES.keys())
    strat_names = strategies or list(STRATEGY_REGISTRY.keys())
    samples = []

    with KantBenchEnv(base_url=base_url) as env:
        attempts = 0
        while len(samples) < n_samples:
            attempts += 1

            # Decide whether to use a variant
            use_variant = random.random() < variant_fraction
            if use_variant:
                game_key = random.choice(VARIANT_BASE_GAMES)
                variant = random.choice(TRAINABLE_VARIANTS)
            else:
                game_key = random.choice(game_keys)
                variant = None

            strategy = random.choice(strat_names)

            try:
                # Reset env — pass variant for dynamic composition
                reset_kwargs = {"game": game_key, "strategy": strategy}
                if variant:
                    reset_kwargs["variant"] = variant

                result = env.reset(**reset_kwargs)
                obs = result.observation

                # Play 0..N-1 random rounds to create diverse game states
                max_rounds = obs.max_rounds
                rounds_to_play = random.randint(0, max(max_rounds - 1, 0))
                for _ in range(rounds_to_play):
                    move = random.choice(obs.available_moves)
                    result = env.step(KantBenchAction(move=move))
                    obs = result.observation
                    if result.done:
                        break

                if result.done:
                    # Episode finished during warm-up; reset so the sampled
                    # prompt still has rounds left to play.
                    result = env.reset(**reset_kwargs)
                    obs = result.observation

                prompt = _build_prompt(obs)

                samples.append({
                    "prompt": prompt,
                    "game_key": game_key,
                    "strategy": strategy,
                    "variant": variant or "",
                    "available_moves": list(obs.available_moves),
                    "rounds_remaining": obs.max_rounds - obs.round_number,
                })
            except Exception as exc:
                logger.debug(
                    "Skipping %s/%s (variant=%s): %s",
                    game_key, strategy, variant, exc,
                )
                continue

    return Dataset.from_list(samples)


# ---------------------------------------------------------------------------
# Reward function — full episode rollout
# ---------------------------------------------------------------------------


def make_reward_fn(base_url: str):
    """Returns a GRPO reward function that plays full episodes via OpenEnv.

    For each completion:
    1. Parse the move from the LLM output
    2. Reset the KantBench server with the correct game/strategy/variant
    3. Play the FULL episode using the parsed move as a consistent strategy
    4. Compute composite reward: payoff + cooperation + Pareto + fairness
    """
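    # One long-lived client connection, reused across every reward call.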
    env = KantBenchEnv(base_url=base_url)
    env.connect()

    def reward_fn(
        completions: list[str],
        prompts: list[str],
        **kwargs: Any,
    ) -> list[float]:
        rewards = []
        game_keys = kwargs.get("game_key", ["prisoners_dilemma"] * len(completions))
        strategies = kwargs.get("strategy", ["tit_for_tat"] * len(completions))
        variants = kwargs.get("variant", [""] * len(completions))
        available_moves_batch = kwargs.get(
            "available_moves", [["cooperate", "defect"]] * len(completions)
        )

        for completion, game_key, strategy, variant, moves in zip(
            completions, game_keys, strategies, variants, available_moves_batch
        ):
            # Parse move from LLM output
            action_str = parse_action(completion.strip(), moves)

            try:
                # Play a full episode using this move as a consistent strategy
                reset_kwargs = {"game": game_key, "strategy": strategy}
                if variant:
                    reset_kwargs["variant"] = variant

                result = env.reset(**reset_kwargs)
                while not result.done:
                    result = env.step(KantBenchAction(move=action_str))

                obs = result.observation

                # Compute cooperation rate from observation history
                coop_rate = _obs_cooperation_rate(obs)

                # Composite reward from the reward module.
                # opponent_score not directly available in KantBenchObservation,
                # approximate from history
                opp_score = sum(
                    h.get("opponent_payoff", 0.0) for h in obs.history
                )
                reward = episode_reward(
                    player_score=obs.cumulative_score,
                    opponent_score=opp_score,
                    cooperation_rate=coop_rate,
                    total_rounds=obs.round_number,
                )
                rewards.append(reward)

            except (ValueError, KeyError, RuntimeError, ConnectionError) as exc:
                logger.debug("Reward error for %s/%s: %s", game_key, action_str, exc)
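                # A fixed penalty keeps one reward per completion, so the
                # group-relative ranking still gets a signal on failure.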
                rewards.append(-1.0)

        return rewards

    return reward_fn


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def parse_args():
    p = argparse.ArgumentParser(description="KantBench GRPO Training")
    p.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct")
    p.add_argument("--output-dir", default="./kantbench-grpo")
    p.add_argument("--env-url", default=KANTBENCH_URL,
                   help="KantBench OpenEnv server URL")
    p.add_argument("--episodes", type=int, default=1000, help="Training dataset size")
    p.add_argument("--num-generations", type=int, default=8, help="GRPO group size")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-6)
    p.add_argument("--max-steps", type=int, default=500)
    p.add_argument("--report-to", default="wandb", help="wandb, tensorboard, or none")
    p.add_argument("--push-to-hub", action="store_true")
    p.add_argument("--hub-model-id", default="jtowarek/kantbench-qwen2.5-7b")
    p.add_argument("--use-train-split", action="store_true",
                   help="Use stratified train/eval split (eval games held out)")
    p.add_argument("--variant-fraction", type=float, default=VARIANT_FRACTION,
                   help="Fraction of samples using dynamic variant composition")
    return p.parse_args()


def main():
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    print(f"Loading model: {args.model}")
    print(f"Output: {args.output_dir}")
    print(f"OpenEnv server: {args.env_url}")

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Optionally use stratified train/eval split
    train_games = None
    if args.use_train_split:
        train_set, eval_set = get_train_eval_split()
        train_games = sorted(train_set)
        print(f"Using stratified split: {len(train_games)} train, {len(eval_set)} eval games")

    dataset = build_dataset(
        args.env_url, args.episodes, games=train_games,
        variant_fraction=args.variant_fraction,
    )
    variant_count = sum(1 for v in dataset["variant"] if v)
    print(f"Dataset: {len(dataset)} prompts across {len(GAMES)} games")
    print(f"  Variant samples: {variant_count} ({variant_count*100//max(len(dataset),1)}%)")

    # Format prompts with chat template
    def format_prompt(example):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["prompt"]},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        }

    dataset = dataset.map(format_prompt)

    reward_fn = make_reward_fn(args.env_url)

    config = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=32,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_steps=args.max_steps,
        logging_steps=10,
        save_steps=100,
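        # bfloat16 needs compute capability >= 8.0 (Ampere or newer);
        # earlier CUDA GPUs fall back to fp16.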
        bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
        fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
        report_to=args.report_to,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
    )

    trainer = GRPOTrainer(
        model=args.model,
        reward_funcs=reward_fn,
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print("Starting GRPO training...")
print(f" Reward: composite (payoff + cooperation + Pareto + fairness)")
|
| 395 |
+
print(f" Episode: full multi-round rollout via OpenEnv @ {args.env_url}")
|
| 396 |
+
print(f" Variants: {args.variant_fraction*100:.0f}% of samples use dynamic composition")
|
| 397 |
+
trainer.train()
|
| 398 |
+
trainer.save_model(args.output_dir)
|
| 399 |
+
print(f"Done. Model saved to {args.output_dir}")
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
if __name__ == "__main__":
|
| 403 |
+
main()
|
train/trajectory.py
ADDED
@@ -0,0 +1,206 @@
+
"""Trajectory collection for training data generation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from env.models import GameAction, GameObservation, RoundResult
|
| 9 |
+
from env.environment import KantEnvironment
|
| 10 |
+
from constant_definitions.game_constants import EVAL_ZERO_FLOAT
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class StepRecord:
|
| 15 |
+
"""A single step within an episode trajectory."""
|
| 16 |
+
|
| 17 |
+
prompt: str
|
| 18 |
+
completion: str
|
| 19 |
+
action: str
|
| 20 |
+
reward: float
|
| 21 |
+
player_payoff: float
|
| 22 |
+
opponent_payoff: float
|
| 23 |
+
round_number: int
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class EpisodeTrajectory:
|
| 28 |
+
"""Complete trajectory of one episode."""
|
| 29 |
+
|
| 30 |
+
game: str
|
| 31 |
+
strategy: str
|
| 32 |
+
steps: List[StepRecord] = field(default_factory=list)
|
| 33 |
+
episode_reward: float = EVAL_ZERO_FLOAT
|
| 34 |
+
player_score: float = EVAL_ZERO_FLOAT
|
| 35 |
+
opponent_score: float = EVAL_ZERO_FLOAT
|
| 36 |
+
cooperation_rate: float = EVAL_ZERO_FLOAT
|
| 37 |
+
    rounds_played: int = 0
    metrics: Dict[str, float] = field(default_factory=dict)


class TrajectoryCollector:
    """Runs episodes and collects trajectories for training.

    Parameters
    ----------
    env : KantEnvironment
        The game environment instance.
    agent : LLMAgent
        An agent with ``last_prompt`` / ``last_completion`` properties,
        callable with ``(GameObservation) -> GameAction``.
    reward_fn : callable, optional
        Function(player_score, opponent_score, cooperation_rate, rounds) -> float.
    step_reward_fn : callable, optional
        Function(player_payoff, opponent_payoff, payoff_min, payoff_max) -> float.
    """

    def __init__(
        self,
        env: KantEnvironment,
        agent: Any,
        reward_fn: Optional[Callable[..., float]] = None,
        step_reward_fn: Optional[Callable[..., float]] = None,
    ) -> None:
        self._env = env
        self._agent = agent
        self._reward_fn = reward_fn
        self._step_reward_fn = step_reward_fn

    def collect_episode(
        self,
        game: str,
        strategy: str = "tit_for_tat",
        opponent_fn: Optional[Callable] = None,
    ) -> EpisodeTrajectory:
        """Run a single episode and return its trajectory."""
        if opponent_fn is not None:
            obs = self._env.reset(game=game, opponent_fn=opponent_fn)
        else:
            obs = self._env.reset(game=game, strategy=strategy)
        steps: List[StepRecord] = []

        while not obs.done:
            action = self._agent(obs)

            # Capture prompt/completion from agent
            prompt = getattr(self._agent, "last_prompt", "")
            completion = getattr(self._agent, "last_completion", "")

            next_obs = self._env.step(action)

            # Compute step reward
            step_reward = EVAL_ZERO_FLOAT
            if self._step_reward_fn is not None and next_obs.last_round is not None:
                step_reward = self._step_reward_fn(
                    next_obs.last_round.player_payoff,
                    next_obs.last_round.opponent_payoff,
                    EVAL_ZERO_FLOAT,
                    EVAL_ZERO_FLOAT,
                )

            # Record step
            last_rnd = next_obs.last_round
            steps.append(StepRecord(
                prompt=prompt,
                completion=completion,
                action=action.action,
                reward=step_reward,
                player_payoff=(
                    last_rnd.player_payoff if last_rnd is not None
                    else EVAL_ZERO_FLOAT
                ),
                opponent_payoff=(
                    last_rnd.opponent_payoff if last_rnd is not None
                    else EVAL_ZERO_FLOAT
                ),
                round_number=next_obs.current_round,
            ))
            obs = next_obs

        # Compute cooperation rate (reusing tournament logic pattern)
        coop_rate = _compute_cooperation_rate(obs)

        # Compute episode reward
        ep_reward = EVAL_ZERO_FLOAT
        if self._reward_fn is not None:
            ep_reward = self._reward_fn(
                obs.player_score,
                obs.opponent_score,
                coop_rate,
                obs.current_round,
            )

        return EpisodeTrajectory(
            game=game,
            strategy=strategy,
            steps=steps,
            episode_reward=ep_reward,
            player_score=obs.player_score,
            opponent_score=obs.opponent_score,
            cooperation_rate=coop_rate,
            rounds_played=obs.current_round,
        )

    def collect_batch(
        self,
        games: List[str],
        strategies: Optional[List[str]] = None,
        episodes_per_pair: int = 1,
        opponent_fn: Optional[Callable] = None,
    ) -> List[EpisodeTrajectory]:
        """Collect trajectories for all (game, strategy) combinations.

        If *opponent_fn* is provided, self-play mode is used: only
        games are iterated (strategies are ignored).
        """
        trajectories: List[EpisodeTrajectory] = []
        if opponent_fn is not None:
            for game in games:
                for _ep in range(episodes_per_pair):
                    traj = self.collect_episode(
                        game, opponent_fn=opponent_fn,
                    )
                    trajectories.append(traj)
        else:
            strats = strategies or ["tit_for_tat"]
            for game in games:
                for strategy in strats:
                    for _ep in range(episodes_per_pair):
                        traj = self.collect_episode(game, strategy)
                        trajectories.append(traj)
        return trajectories


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_COOPERATIVE_ACTIONS = frozenset({"cooperate", "stag", "dove"})
_ECONOMIC_PREFIXES = frozenset({"offer", "invest", "contribute"})

_ZERO = 0
_ONE = 1
_TWO = 2


def _compute_cooperation_rate(obs: GameObservation) -> float:
    """Fraction of cooperative moves in an episode."""
    if not obs.history:
        return EVAL_ZERO_FLOAT
    total = len(obs.history)
    cooperative_count = _ZERO
    first_action = obs.history[_ZERO].player_action
    prefix = first_action.split("_")[_ZERO]
    is_economic = prefix in _ECONOMIC_PREFIXES
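    # Heuristic for graded economic games (offer_*, invest_*, contribute_*):
    # moves in the upper half of the ordered action list count as cooperative.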
    if is_economic:
        median_idx = len(obs.available_actions) // _TWO
        for rnd in obs.history:
            act = rnd.player_action
            if act in obs.available_actions:
                if obs.available_actions.index(act) >= median_idx:
                    cooperative_count += _ONE
    else:
        for rnd in obs.history:
            if rnd.player_action in _COOPERATIVE_ACTIONS:
                cooperative_count += _ONE
    return cooperative_count / total
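A minimal sketch of how the collector composes with the environment. The stub agent and the default KantEnvironment() construction are assumptions for illustration; a real run would pass an LLMAgent from train/agent.py and episode_reward from train/rewards.py as reward_fn:

from env.environment import KantEnvironment
from env.models import GameAction, GameObservation
from train.trajectory import TrajectoryCollector


class AlwaysCooperate:
    """Hypothetical stub agent: callable (GameObservation) -> GameAction,
    exposing last_prompt / last_completion as the collector expects."""

    last_prompt = ""
    last_completion = "cooperate"

    def __call__(self, obs: GameObservation) -> GameAction:
        return GameAction(action="cooperate")


collector = TrajectoryCollector(env=KantEnvironment(), agent=AlwaysCooperate())
trajs = collector.collect_batch(["prisoners_dilemma"], ["tit_for_tat"])
print(trajs[0].player_score, trajs[0].cooperation_rate)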