jtowarek committed on
Commit
92dbd69
·
verified ·
1 Parent(s): 4550628

Delete train/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train/train.py +0 -457
train/train.py DELETED
@@ -1,457 +0,0 @@
1
- """KantBench GRPO Training Script.
2
-
3
- Trains a language model to play 2-player game theory games optimally
4
- using Group Relative Policy Optimization (GRPO) via TRL.
5
-
6
- The KantBench environment runs as a remote OpenEnv server (HF Space):
7
- - Each GRPO completion is a single move
8
- - The reward function plays a FULL multi-round episode using that move
9
- as the agent's consistent strategy via the OpenEnv client
10
- - The composite reward (payoff + cooperation + Pareto efficiency + fairness)
11
- becomes the GRPO signal
12
-
13
- Supports the full KantBench game library including:
14
- - 90+ base 2-player games and 3 N-player games
15
- - 9 pre-registered meta-games (rule_proposal, rule_signal, gossip)
16
- - Dynamic variant composition (cheap_talk, exit, binding_commitment,
17
- constitutional, proposer_responder, noisy_actions, noisy_payoffs)
18
-
19
- Usage:
20
- python -m train.train --model Qwen/Qwen2.5-7B-Instruct --max-steps 200
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import argparse
26
- import logging
27
- import random
28
- from typing import Any, List
29
-
30
- import torch
31
- from datasets import Dataset
32
- from trl import GRPOConfig, GRPOTrainer
33
- from transformers import AutoTokenizer
34
-
35
- from common.games import GAMES
36
- from common.strategies import STRATEGIES as STRATEGY_REGISTRY
37
- from spaces.kant.client import KantBenchEnv
38
- from spaces.kant.models import KantBenchAction, KantBenchObservation
39
- from train.agent import parse_action
40
- from train.rewards import episode_reward
41
- from train.splits import get_train_eval_split
42
-
43
# Module-level logger; main() configures logging via logging.basicConfig.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# Default base URL of the KantBench OpenEnv server (default of --env-url).
KANTBENCH_URL = "https://openenv-community-kantbench.hf.space"

# Instruction text. It is appended to every user prompt by _build_prompt and
# also sent as the chat "system" message in main().
SYSTEM_PROMPT = (
    "You are playing a game-theory game. Analyse the situation and choose "
    "the best action. Respond with ONLY the action name, nothing else."
)

# Variants that can be dynamically composed on top of base games.
# These are applied server-side via the variant= reset parameter.
# NOTE(review): the module docstring also lists "proposer_responder", which is
# absent from this list — confirm whether the omission is intentional.
TRAINABLE_VARIANTS = [
    "cheap_talk",
    "exit",
    "binding_commitment",
    "constitutional",
    "noisy_actions",
    "noisy_payoffs",
    "rule_proposal",
    "rule_signal",
    "gossip",
]

# Base games suitable for variant composition (2-player matrix games).
VARIANT_BASE_GAMES = [
    "prisoners_dilemma",
    "stag_hunt",
    "hawk_dove",
]

# Fraction of dataset samples that use dynamic variant composition.
# Used as the default for build_dataset() and for --variant-fraction.
VARIANT_FRACTION = 0.3
79
-
80
-
81
- # ---------------------------------------------------------------------------
82
- # Helpers to bridge KantBenchObservation -> training code
83
- # ---------------------------------------------------------------------------
84
-
85
-
86
- def _obs_cooperation_rate(obs: KantBenchObservation) -> float:
87
- """Compute cooperation rate from a KantBenchObservation's history."""
88
- if not obs.history:
89
- return 0.0
90
- coop_actions = {"cooperate", "stag", "dove", "contribute"}
91
- coop_count = sum(
92
- 1 for h in obs.history
93
- if any(ca in h.get("your_move", "") for ca in coop_actions)
94
- )
95
- return coop_count / len(obs.history)
96
-
97
-
98
def _build_prompt(obs: KantBenchObservation) -> str:
    """Render a KantBenchObservation as a sectioned text prompt.

    The prompt consists of blank-line-separated sections in this order:
    ``[Game]``, ``[History]`` (only when there is history; at most the last
    five rounds), ``[Scores]``, ``[Available Actions]`` and ``[Instruction]``.
    The layout mirrors PromptBuilder.build() but reads the OpenEnv client's
    observation format.
    """
    parts: list[str] = [f"[Game]\n{obs.game_name}\n{obs.game_description}"]

    # Only the five most recent rounds are shown.
    if obs.history:
        rows = "\n".join(
            " | ".join((
                f"Round {rec.get('round', '?')}",
                f"You played: {rec.get('your_move', '?')}",
                f"Opponent played: {rec.get('opponent_move', '?')}",
                f"Your payoff: {rec.get('your_payoff', '?')}",
                f"Opp payoff: {rec.get('opponent_payoff', '?')}",
            ))
            for rec in obs.history[-5:]
        )
        parts.append(f"[History]\n{rows}")

    parts.append(
        f"[Scores]\nYour score: {obs.cumulative_score}"
        f"\nRound: {obs.round_number} of {obs.max_rounds}"
    )

    move_listing = "\n".join(f"- {name}" for name in obs.available_moves)
    parts.append(f"[Available Actions]\n{move_listing}")

    parts.append(f"[Instruction]\n{SYSTEM_PROMPT}")

    return "\n\n".join(parts)
139
-
140
- # ---------------------------------------------------------------------------
141
- # Dataset generation using PromptBuilder
142
- # ---------------------------------------------------------------------------
143
-
144
-
145
def build_dataset(
    base_url: str,
    n_samples: int = 1000,
    games: list[str] | None = None,
    strategies: list[str] | None = None,
    variant_fraction: float = VARIANT_FRACTION,
    max_attempts: int | None = None,
) -> Dataset:
    """Generate diverse game theory prompts for GRPO training.

    Connects to the KantBench OpenEnv server to generate real observations,
    then builds structured prompts from diverse game states.

    A fraction of samples use dynamic variant composition (cheap_talk,
    constitutional, gossip, etc.) to train on meta-gaming scenarios.

    Args:
        base_url: URL of the KantBench OpenEnv server.
        n_samples: Number of prompts to collect.
        games: Game keys to sample from; defaults to every key in ``GAMES``.
        strategies: Opponent strategy names; defaults to every registered one.
        variant_fraction: Probability that a sample uses a base game from
            ``VARIANT_BASE_GAMES`` combined with a random trainable variant.
        max_attempts: Upper bound on sampling attempts (successful or not).
            Defaults to ``max(10 * n_samples, 100)``.

    Returns:
        A ``Dataset`` with columns ``prompt``, ``game_key``, ``strategy``,
        ``variant`` (empty string when none), ``available_moves`` and
        ``rounds_remaining``.

    Raises:
        RuntimeError: If ``max_attempts`` is exhausted before ``n_samples``
            prompts were collected (for example because the server is
            unreachable). Previously this case looped forever.
    """
    game_keys = games or list(GAMES.keys())
    strat_names = strategies or list(STRATEGY_REGISTRY.keys())
    if max_attempts is None:
        max_attempts = max(10 * n_samples, 100)
    samples: list[dict[str, Any]] = []
    last_error: Exception | None = None

    with KantBenchEnv(base_url=base_url) as env:
        attempts = 0
        while len(samples) < n_samples:
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Collected only {len(samples)}/{n_samples} samples after "
                    f"{attempts} attempts against {base_url}"
                ) from last_error
            attempts += 1

            # Decide whether to use a variant
            use_variant = random.random() < variant_fraction
            if use_variant:
                game_key = random.choice(VARIANT_BASE_GAMES)
                variant = random.choice(TRAINABLE_VARIANTS)
            else:
                game_key = random.choice(game_keys)
                variant = None

            strategy = random.choice(strat_names)

            try:
                # Reset env — pass variant for dynamic composition
                reset_kwargs = {"game": game_key, "strategy": strategy}
                if variant:
                    reset_kwargs["variant"] = variant

                result = env.reset(**reset_kwargs)
                obs = result.observation

                # Play 0..N-1 random rounds to create diverse game states
                max_rounds = obs.max_rounds
                rounds_to_play = random.randint(0, max(max_rounds - 1, 0))
                for _ in range(rounds_to_play):
                    move = random.choice(obs.available_moves)
                    result = env.step(KantBenchAction(move=move))
                    obs = result.observation
                    if result.done:
                        break

                if result.done:
                    # The episode ended early, so there is no decision left to
                    # prompt for; fall back to a fresh initial state.
                    result = env.reset(**reset_kwargs)
                    obs = result.observation

                prompt = _build_prompt(obs)

                samples.append({
                    "prompt": prompt,
                    "game_key": game_key,
                    "strategy": strategy,
                    "variant": variant or "",
                    "available_moves": list(obs.available_moves),
                    "rounds_remaining": obs.max_rounds - obs.round_number,
                })
            except Exception as exc:
                # Best-effort sampling: any failing game/strategy/variant
                # combination is skipped. The attempt cap above keeps a
                # persistently failing server from hanging the run.
                last_error = exc
                logger.debug(
                    "Skipping %s/%s (variant=%s): %s",
                    game_key, strategy, variant, exc,
                )

    return Dataset.from_list(samples)
222
-
223
-
224
- # ---------------------------------------------------------------------------
225
- # Reward function — full episode rollout
226
- # ---------------------------------------------------------------------------
227
-
228
-
229
def make_reward_fn(base_url: str):
    """Build a GRPO reward function that plays full episodes via OpenEnv.

    One ``KantBenchEnv`` client is created and connected here and reused by
    every call of the returned function; nothing in this function closes it.

    For each completion the returned function:
    1. Parses the move from the LLM output
    2. Resets the KantBench server with the sample's game/strategy/variant
    3. Plays the FULL episode, submitting the parsed move on every step
    4. Computes the composite reward via ``episode_reward``

    Args:
        base_url: URL of the KantBench OpenEnv server.

    Returns:
        A callable ``reward_fn(completions, prompts, **kwargs)`` returning one
        float per completion.
    """
    env = KantBenchEnv(base_url=base_url)
    env.connect()

    def reward_fn(
        completions: list[str],
        prompts: list[str],
        **kwargs: Any,
    ) -> list[float]:
        """Score each completion by rolling out a full episode.

        Per-sample metadata is read from ``kwargs`` (``game_key``,
        ``strategy``, ``variant``, ``available_moves``), with
        prisoner's-dilemma defaults when a column is absent. A rollout that
        raises one of the handled errors scores ``-1.0``.
        """
        rewards = []
        game_keys = kwargs.get("game_key", ["prisoners_dilemma"] * len(completions))
        strategies = kwargs.get("strategy", ["tit_for_tat"] * len(completions))
        variants = kwargs.get("variant", [""] * len(completions))
        available_moves_batch = kwargs.get(
            "available_moves", [["cooperate", "defect"]] * len(completions)
        )

        for i, (completion, game_key, strategy, variant, moves) in enumerate(zip(
            completions, game_keys, strategies, variants, available_moves_batch
        )):
            # Parse move from LLM output
            action_str = parse_action(completion.strip(), moves)

            # Log first few completions per batch for debugging
            if i < 3:
                logger.info(
                    "Completion [%d] game=%s moves=%s -> parsed=%s | raw=%r",
                    i, game_key, moves, action_str, completion[:200],
                )

            try:
                # Play a full episode using this move as a consistent strategy
                reset_kwargs = {"game": game_key, "strategy": strategy}
                if variant:
                    reset_kwargs["variant"] = variant

                result = env.reset(**reset_kwargs)
                while not result.done:
                    result = env.step(KantBenchAction(move=action_str))

                obs = result.observation

                # Compute cooperation rate from observation history
                coop_rate = _obs_cooperation_rate(obs)

                # opponent_score is not directly available in
                # KantBenchObservation, so sum it from the per-round history.
                opp_score = sum(
                    h.get("opponent_payoff", 0.0) for h in obs.history
                )
                reward = episode_reward(
                    player_score=obs.cumulative_score,
                    opponent_score=opp_score,
                    cooperation_rate=coop_rate,
                    total_rounds=obs.round_number,
                )
                rewards.append(reward)

            except (ValueError, KeyError, RuntimeError, ConnectionError) as exc:
                # Penalise the failed rollout instead of aborting the batch,
                # but log at WARNING: at DEBUG a broken server would silently
                # turn every reward into -1.0 under the INFO level main() sets.
                logger.warning(
                    "Reward error for %s/%s: %s", game_key, action_str, exc
                )
                rewards.append(-1.0)

        return rewards

    return reward_fn
303
-
304
-
305
def format_reward_fn(
    completions: list[str],
    prompts: list[str],
    **kwargs: Any,
) -> list[float]:
    """Score how cleanly each completion names one of its available moves.

    After stripping surrounding whitespace, a completion earns 1.0 when it
    equals a move exactly, 0.5 when it matches only case-insensitively, 0.1
    when a move merely occurs inside it, and -0.5 when no move is found.
    Moves are read from ``kwargs["available_moves"]`` and default to
    ``["cooperate", "defect"]`` for every completion.
    """
    fallback_moves = [["cooperate", "defect"]] * len(completions)
    moves_per_completion = kwargs.get("available_moves", fallback_moves)

    def _score(text: str, moves: list[str]) -> float:
        candidate = text.strip()
        if candidate in moves:
            return 1.0
        folded = candidate.lower()
        lowered_moves = [m.lower() for m in moves]
        if folded in lowered_moves:
            return 0.5
        if any(m in folded for m in lowered_moves):
            return 0.1
        return -0.5

    return [
        _score(text, moves)
        for text, moves in zip(completions, moves_per_completion)
    ]
330
-
331
-
332
- # ---------------------------------------------------------------------------
333
- # Main
334
- # ---------------------------------------------------------------------------
335
-
336
-
337
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="KantBench GRPO Training")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--model", {"default": "Qwen/Qwen2.5-7B-Instruct"}),
        ("--output-dir", {"default": "./kantbench-grpo"}),
        ("--env-url", {"default": KANTBENCH_URL,
                       "help": "KantBench OpenEnv server URL"}),
        ("--episodes", {"type": int, "default": 1000,
                        "help": "Training dataset size"}),
        ("--num-generations", {"type": int, "default": 8,
                               "help": "GRPO group size"}),
        ("--batch-size", {"type": int, "default": 4}),
        ("--grad-accum", {"type": int, "default": 4}),
        ("--lr", {"type": float, "default": 3e-6}),
        ("--max-steps", {"type": int, "default": 500}),
        ("--report-to", {"default": "wandb",
                         "help": "wandb, tensorboard, or none"}),
        ("--push-to-hub", {"action": "store_true"}),
        ("--hub-model-id", {"default": "jtowarek/kantbench-qwen2.5-7b"}),
        ("--use-train-split", {
            "action": "store_true",
            "help": "Use stratified train/eval split (eval games held out)",
        }),
        ("--variant-fraction", {
            "type": float,
            "default": VARIANT_FRACTION,
            "help": "Fraction of samples using dynamic variant composition",
        }),
        ("--resume-from-checkpoint", {
            "type": str,
            "default": None,
            "help": "Path to checkpoint or 'latest' to resume training",
        }),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
359
-
360
-
361
def main():
    """Run GRPO training end to end from the command line.

    Steps: parse arguments, load the tokenizer, build the prompt dataset from
    the OpenEnv server, wrap each prompt in the model's chat template,
    configure ``GRPOTrainer`` with the episode and format rewards, train
    (optionally resuming from a checkpoint) and save the model to
    ``--output-dir``.
    """
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    print(f"Loading model: {args.model}")
    print(f"Output: {args.output_dir}")
    print(f"OpenEnv server: {args.env_url}")

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Optionally use stratified train/eval split
    train_games = None
    if args.use_train_split:
        train_set, eval_set = get_train_eval_split()
        train_games = sorted(train_set)
        print(f"Using stratified split: {len(train_games)} train, {len(eval_set)} eval games")

    dataset = build_dataset(
        args.env_url, args.episodes, games=train_games,
        variant_fraction=args.variant_fraction,
    )
    variant_count = sum(1 for v in dataset["variant"] if v)
    # Report the games actually present in the dataset: len(GAMES) would
    # overstate coverage whenever the train split restricts the game set.
    distinct_games = len(set(dataset["game_key"]))
    print(f"Dataset: {len(dataset)} prompts across {distinct_games} games")
    print(f" Variant samples: {variant_count} ({variant_count*100//max(len(dataset),1)}%)")

    # Format prompts with chat template
    def format_prompt(example):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["prompt"]},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        }

    dataset = dataset.map(format_prompt)

    reward_fn = make_reward_fn(args.env_url)

    # Use bf16 on GPUs with compute capability >= 8, fp16 on older GPUs, and
    # full precision when CUDA is unavailable.
    cuda_ok = torch.cuda.is_available()
    supports_bf16 = cuda_ok and torch.cuda.get_device_capability()[0] >= 8

    config = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=32,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="constant_with_warmup",
        warmup_steps=50,
        max_steps=args.max_steps,
        logging_steps=10,
        save_steps=100,
        save_total_limit=2,
        bf16=supports_bf16,
        fp16=cuda_ok and not supports_bf16,
        report_to=args.report_to,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
        generation_kwargs={"temperature": 0.7},
    )

    # Add newline token as an extra EOS so generation stops after one line,
    # enforcing single-action output.
    newline_token_id = tokenizer.encode("\n", add_special_tokens=False)
    if newline_token_id:
        config.generation_kwargs["eos_token_id"] = [
            tokenizer.eos_token_id, newline_token_id[0],
        ]

    trainer = GRPOTrainer(
        model=args.model,
        reward_funcs=[reward_fn, format_reward_fn],
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    resume_ckpt = args.resume_from_checkpoint
    if resume_ckpt == "latest":
        resume_ckpt = True  # Trainer auto-finds latest checkpoint in output_dir

    print("Starting GRPO training...")
    print(" Reward: composite (payoff + cooperation + Pareto + fairness)")
    print(f" Episode: full multi-round rollout via OpenEnv @ {args.env_url}")
    print(f" Variants: {args.variant_fraction*100:.0f}% of samples use dynamic composition")
    if resume_ckpt:
        print(f" Resuming from checkpoint: {resume_ckpt}")
    trainer.train(resume_from_checkpoint=resume_ckpt)
    trainer.save_model(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")
454
-
455
-
456
# Script entry point: start training only when this module is executed
# directly (e.g. ``python -m train.train``), not when it is imported.
if __name__ == "__main__":
    main()