jtowarek commited on
Commit
141582a
·
verified ·
1 Parent(s): 5056447

Delete train/self_play/trainer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train/self_play/trainer.py +0 -276
train/self_play/trainer.py DELETED
@@ -1,276 +0,0 @@
1
- """Self-play GRPO trainer for multi-agent training."""
2
-
3
- from __future__ import annotations
4
-
5
- import copy
6
- import logging
7
- import random
8
- from typing import Any, Callable, Dict, List, Optional
9
-
10
- from env.environment import KantEnvironment
11
- from env.models import GameAction, GameObservation
12
- from train.agent import LLMAgent, PromptBuilder, parse_action
13
- from train.rewards import episode_reward
14
- from train.trajectory import TrajectoryCollector, EpisodeTrajectory
15
- from train.self_play.opponents import FrozenOpponent, OpponentPool
16
- from train.self_play.config import SelfPlayConfig
17
- from constant_definitions.train.agent_constants import SYSTEM_PROMPT
18
- from constant_definitions.train.grpo_constants import GRPO_LOG_EVERY
19
- from constant_definitions.game_constants import EVAL_ZERO_FLOAT
20
- from constant_definitions.var.meta.self_play_constants import (
21
- SELF_PLAY_COOP_WEIGHT_DENOMINATOR,
22
- SELF_PLAY_COOP_WEIGHT_NUMERATOR,
23
- SELF_PLAY_EXPLOIT_WEIGHT_DENOMINATOR,
24
- SELF_PLAY_EXPLOIT_WEIGHT_NUMERATOR,
25
- SELF_PLAY_FAIRNESS_WEIGHT_DENOMINATOR,
26
- SELF_PLAY_FAIRNESS_WEIGHT_NUMERATOR,
27
- SELF_PLAY_PARETO_WEIGHT_DENOMINATOR,
28
- SELF_PLAY_PARETO_WEIGHT_NUMERATOR,
29
- SELF_PLAY_ADAPT_WEIGHT_DENOMINATOR,
30
- SELF_PLAY_ADAPT_WEIGHT_NUMERATOR,
31
- SELF_PLAY_OPPONENT_LABEL,
32
- )
33
-
34
- logger = logging.getLogger(__name__)
35
-
36
- _ZERO = int()
37
- _ONE = int(bool(True))
38
-
39
-
40
- def _self_play_weights() -> Dict[str, float]:
41
- """Return reward weights tuned for self-play training."""
42
- return {
43
- "exploitation_resistance": (
44
- SELF_PLAY_EXPLOIT_WEIGHT_NUMERATOR
45
- / SELF_PLAY_EXPLOIT_WEIGHT_DENOMINATOR
46
- ),
47
- "cooperation_rate": (
48
- SELF_PLAY_COOP_WEIGHT_NUMERATOR
49
- / SELF_PLAY_COOP_WEIGHT_DENOMINATOR
50
- ),
51
- "pareto_efficiency": (
52
- SELF_PLAY_PARETO_WEIGHT_NUMERATOR
53
- / SELF_PLAY_PARETO_WEIGHT_DENOMINATOR
54
- ),
55
- "fairness_index": (
56
- SELF_PLAY_FAIRNESS_WEIGHT_NUMERATOR
57
- / SELF_PLAY_FAIRNESS_WEIGHT_DENOMINATOR
58
- ),
59
- "adaptability": (
60
- SELF_PLAY_ADAPT_WEIGHT_NUMERATOR
61
- / SELF_PLAY_ADAPT_WEIGHT_DENOMINATOR
62
- ),
63
- }
64
-
65
-
66
- class SelfPlayTrainer:
67
- """GRPO training with self-play opponents.
68
-
69
- Training loop:
70
- 1. Collect trajectories: training model vs frozen opponent
71
- 2. Compute GRPO rewards from episode outcomes
72
- 3. Update training model via TRL GRPOTrainer
73
- 4. Periodically refresh frozen opponent from training model
74
- 5. Add old opponent to pool for diversity
75
-
76
- Parameters
77
- ----------
78
- config : SelfPlayConfig
79
- Training configuration.
80
- model : object
81
- HuggingFace model to train.
82
- tokenizer : object
83
- Tokenizer for the model.
84
- env : KantEnvironment, optional
85
- Game environment instance.
86
- """
87
-
88
- def __init__(
89
- self,
90
- config: SelfPlayConfig,
91
- model: object,
92
- tokenizer: object,
93
- env: Optional[KantEnvironment] = None,
94
- ) -> None:
95
- self._config = config
96
- self._model = model
97
- self._tokenizer = tokenizer
98
- self._env = env or KantEnvironment()
99
- self._pool = OpponentPool(max_size=config.pool_max_size)
100
- self._frozen = FrozenOpponent.from_model(model, tokenizer)
101
- self._pool.add(self._frozen)
102
- self._step_count = _ZERO
103
-
104
- def _model_generate(self, prompt: str) -> str:
105
- """Generate a completion from the training model."""
106
- import torch
107
-
108
- with torch.no_grad():
109
- inputs = self._tokenizer(prompt, return_tensors="pt")
110
- input_len = len(inputs["input_ids"][_ZERO])
111
- outputs = self._model.generate(
112
- **inputs,
113
- max_new_tokens=self._config.max_completion_length,
114
- )
115
- return self._tokenizer.decode(
116
- outputs[_ZERO][input_len:],
117
- skip_special_tokens=True,
118
- )
119
-
120
- def collect_trajectories(
121
- self,
122
- games: List[str],
123
- num_episodes: int,
124
- ) -> List[EpisodeTrajectory]:
125
- """Collect episodes with current frozen opponent."""
126
- agent = LLMAgent(generate_fn=self._model_generate)
127
- collector = TrajectoryCollector(
128
- env=self._env,
129
- agent=agent,
130
- reward_fn=lambda ps, os, cr, tr: episode_reward(
131
- ps, os, cr, tr, weights=_self_play_weights(),
132
- ),
133
- )
134
- trajectories: List[EpisodeTrajectory] = []
135
- for _ep in range(num_episodes):
136
- game = random.choice(games)
137
- opponent = self._pool.sample()
138
- traj = collector.collect_episode(
139
- game=game,
140
- strategy=SELF_PLAY_OPPONENT_LABEL,
141
- opponent_fn=opponent,
142
- )
143
- trajectories.append(traj)
144
- return trajectories
145
-
146
- def make_reward_fn(self) -> Callable[..., List[float]]:
147
- """Create GRPO reward function using self-play episodes."""
148
- pool = self._pool
149
- env = self._env
150
- weights = _self_play_weights()
151
-
152
- def reward_fn(
153
- completions: List[str],
154
- prompts: List[str],
155
- **kwargs: Any,
156
- ) -> List[float]:
157
- rewards: List[float] = []
158
- game_keys = kwargs.get(
159
- "game_key",
160
- ["prisoners_dilemma"] * len(completions),
161
- )
162
- moves_batch = kwargs.get(
163
- "available_moves",
164
- [["cooperate", "defect"]] * len(completions),
165
- )
166
- for completion, game_key, moves in zip(
167
- completions, game_keys, moves_batch,
168
- ):
169
- action_str = parse_action(completion.strip(), moves)
170
- opponent = pool.sample()
171
- obs = env.reset(
172
- game=game_key, opponent_fn=opponent,
173
- )
174
- while not obs.done:
175
- obs = env.step(GameAction(action=action_str))
176
- reward = episode_reward(
177
- obs.player_score,
178
- obs.opponent_score,
179
- _compute_coop_rate(obs),
180
- obs.current_round,
181
- weights=weights,
182
- )
183
- rewards.append(reward)
184
- return rewards
185
-
186
- return reward_fn
187
-
188
- def refresh_opponent(self) -> None:
189
- """Copy current training model to a new frozen opponent."""
190
- frozen_model = copy.deepcopy(self._model)
191
- frozen_model.eval()
192
- new_opponent = FrozenOpponent.from_model(
193
- frozen_model, self._tokenizer,
194
- )
195
- self._pool.add(new_opponent)
196
- self._frozen = new_opponent
197
- logger.info(
198
- "Refreshed opponent. Pool size: %d", self._pool.size,
199
- )
200
-
201
- def train(self, games: List[str]) -> None:
202
- """Main self-play training loop.
203
-
204
- Parameters
205
- ----------
206
- games : list of str
207
- Game keys to train on.
208
- """
209
- from datasets import Dataset
210
- from trl import GRPOConfig, GRPOTrainer
211
- import torch
212
-
213
- trajectories = self.collect_trajectories(
214
- games, self._config.warmup_episodes,
215
- )
216
- samples = []
217
- for traj in trajectories:
218
- for step in traj.steps:
219
- messages = [
220
- {"role": "system", "content": SYSTEM_PROMPT},
221
- {"role": "user", "content": step.prompt},
222
- ]
223
- formatted = self._tokenizer.apply_chat_template(
224
- messages, tokenize=False,
225
- add_generation_prompt=True,
226
- )
227
- samples.append({
228
- "prompt": formatted,
229
- "game_key": traj.game,
230
- "available_moves": ["cooperate", "defect"],
231
- })
232
- dataset = Dataset.from_list(samples)
233
-
234
- reward_fn = self.make_reward_fn()
235
-
236
- trl_config = GRPOConfig(
237
- output_dir=self._config.output_dir,
238
- num_generations=self._config.num_generations,
239
- max_completion_length=self._config.max_completion_length,
240
- per_device_train_batch_size=self._config.batch_size,
241
- learning_rate=self._config.learning_rate,
242
- max_steps=self._config.max_steps,
243
- logging_steps=GRPO_LOG_EVERY,
244
- save_steps=self._config.opponent_update_interval,
245
- bf16=torch.cuda.is_available(),
246
- )
247
-
248
- trainer = GRPOTrainer(
249
- model=self._model,
250
- reward_funcs=reward_fn,
251
- args=trl_config,
252
- train_dataset=dataset,
253
- processing_class=self._tokenizer,
254
- )
255
-
256
- trainer.train()
257
- trainer.save_model(self._config.output_dir)
258
-
259
-
260
- # ---------------------------------------------------------------------------
261
- # Helpers
262
- # ---------------------------------------------------------------------------
263
-
264
- _COOPERATIVE_ACTIONS = frozenset({"cooperate", "stag", "dove"})
265
-
266
-
267
- def _compute_coop_rate(obs: GameObservation) -> float:
268
- """Fraction of cooperative moves in an episode."""
269
- if not obs.history:
270
- return EVAL_ZERO_FLOAT
271
- total = len(obs.history)
272
- count = _ZERO
273
- for rnd in obs.history:
274
- if rnd.player_action in _COOPERATIVE_ACTIONS:
275
- count += _ONE
276
- return count / total