# GRPO training script for a Diplomacy negotiation agent (TRL).
| import argparse | |
| import json | |
| import random | |
| from typing import Any, Dict, List | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from datasets import Dataset | |
| from transformers import AutoTokenizer | |
| from trl import GRPOConfig, GRPOTrainer | |
| from reward_model import DiplomacyRewardModel, score_state | |
def build_prompt(state_text: str, power_name: str) -> str:
    """Return the negotiation prompt shown to the policy model for one state."""
    header = (
        "You are a Diplomacy negotiation agent. Your goal is to maximize long-term strategic advantage "
        f"for {power_name} by capturing and holding supply centers while avoiding elimination."
    )
    question = (
        f"What is your next strategic move as {power_name}? "
        "Respond with a concise description of your intent and planned orders."
    )
    return f"{header}\n\nCurrent game state:\n\n{state_text}\n\n{question}"
def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the GRPO training run."""
    parser = argparse.ArgumentParser(
        description="GRPO training loop for Diplomacy negotiation using TRL."
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-5,
        help="Learning rate for GRPO (default: 1e-5).",
    )
    parser.add_argument(
        "--reward_model_path",
        type=str,
        default="reward_model.pt",
        help="Path to trained reward model weights (default: reward_model.pt).",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=100,
        help="Approximate number of GRPO optimization steps (default: 100).",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="selfplay_states.json",
        help="Path to self-play states JSON file (default: selfplay_states.json).",
    )
    return parser.parse_args()


def _load_reward_model(path: str, device: torch.device):
    """Load the trained DistilBERT-based scalar reward model onto *device* in eval mode."""
    model = DiplomacyRewardModel().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


def _build_dataset(dataset_path: str):
    """Read the self-play states JSON file and build the GRPO prompt dataset.

    Each record is expected to carry ``state_text``, ``power`` and ``reward``
    keys; missing keys fall back to an empty state, ENGLAND, and 0.0.
    """
    with open(dataset_path, "r") as f:
        raw_states: List[Dict[str, Any]] = json.load(f)
    prompts: List[str] = []
    base_rewards: List[float] = []
    for ex in raw_states:
        prompts.append(build_prompt(ex.get("state_text", ""), ex.get("power", "ENGLAND")))
        base_rewards.append(float(ex.get("reward", 0.0)))
    return Dataset.from_dict({"prompt": prompts, "reward": base_rewards})


def _make_reward_fn(reward_model, reward_tokenizer, device):
    """Build the GRPO reward callable: reward-model score plus repetition/length penalties."""

    def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
        """Score each completion with the reward model, then apply penalties."""
        scores: List[float] = []
        for completion in completions:
            text = completion.strip()
            words = text.split()
            # Reward-model score: encode with DistilBERT, read the [CLS] vector,
            # map to a scalar through the model head.
            inputs = reward_tokenizer(
                text,
                return_tensors="pt",
                max_length=128,
                truncation=True,
                padding=True,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                hidden = reward_model.encoder(**inputs).last_hidden_state[:, 0, :]
                rm_score = reward_model.head(hidden).squeeze().item()
            # Repetition penalty — ratio of unique words to total words.
            unique_ratio = len(set(words)) / len(words) if words else 0.0
            if unique_ratio < 0.3:
                repetition_penalty = -2.0
            elif unique_ratio < 0.5:
                repetition_penalty = -0.5
            else:
                repetition_penalty = 0.0
            # Penalty for very short completions.
            length_penalty = -0.5 if len(words) < 5 else 0.0
            # BUG FIX: append plain floats, matching the declared List[float]
            # return type (the original appended 0-d torch tensors).
            scores.append(float(rm_score) + repetition_penalty + length_penalty)
        return scores

    return reward_fn


def _extract_rewards(trainer) -> List[float]:
    """Collect per-step rewards: prefer TRL's internal ``_metrics``, fall back to log_history."""
    try:
        rewards = list(trainer._metrics["train"]["reward"])
        print(f"Total reward datapoints from _metrics: {len(rewards)}")
    except (KeyError, AttributeError):
        rewards = []
    if not rewards:
        for entry in trainer.state.log_history:
            for key in ("reward", "rewards/reward_fn/mean"):
                if key in entry:
                    rewards.append(entry[key])
                    break
    print(f"Total reward datapoints: {len(rewards)}")
    return rewards


def _plot_rewards(rewards: List[float]) -> None:
    """Plot the per-step reward curve (plus a moving average) and save it to disk."""
    if not rewards:
        return
    plt.figure(figsize=(12, 5))
    plt.plot(rewards, alpha=0.6, label="Step Reward")
    if len(rewards) >= 10:
        window = max(len(rewards) // 10, 1)
        # BUG FIX: divide by the actual chunk length. The original divided by
        # min(i + 1, window) while the slice holds up to window + 1 elements,
        # skewing every averaged point once i >= window.
        moving_avg: List[float] = []
        for i in range(len(rewards)):
            chunk = rewards[max(0, i - window): i + 1]
            moving_avg.append(sum(chunk) / len(chunk))
        plt.plot(moving_avg, linewidth=2, label="Moving Average")
    plt.xlabel("Training Step")
    plt.ylabel("Reward")
    plt.title("Play-gent GRPO Reward Curve")
    plt.legend()
    plt.tight_layout()
    # NOTE(review): filename kept as "ppo_reward_curve.png" for backward
    # compatibility with downstream tooling, even though the algorithm is GRPO.
    plt.savefig("ppo_reward_curve.png")
    print(
        f"Curve saved. Final: {rewards[-1]:.3f}, "
        f"Best: {max(rewards):.3f}"
    )


def main() -> None:
    """Run the full GRPO fine-tuning loop and save a reward-curve plot."""
    args = _parse_args()

    # Device / GPU info.
    print("torch.cuda.is_available():", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("Using GPU:", torch.cuda.get_device_name(0))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Policy model identifier (1B chat model, ungated).
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"Loading policy model: {model_name}")

    # Load trained reward model (DistilBERT-based scalar scorer).
    reward_model_path = args.reward_model_path.strip()
    print("Loading reward model from:", reward_model_path)
    reward_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    reward_model = _load_reward_model(reward_model_path, device)

    # Load self-play dataset and build GRPO prompts.
    print(f"Loading self-play dataset from: {args.dataset_path}")
    dataset = _build_dataset(args.dataset_path)
    print(f"Dataset size: {len(dataset)} examples")

    reward_fn = _make_reward_fn(reward_model, reward_tokenizer, device)

    # GRPO configuration: small batch, multiple generations per prompt.
    grpo_config = GRPOConfig(
        output_dir="grpo_output",
        learning_rate=args.lr,
        per_device_train_batch_size=4,
        num_generations=4,
        max_completion_length=64,
        max_steps=args.episodes,
        logging_steps=1,
        save_steps=100,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model=model_name,
        reward_funcs=reward_fn,
        args=grpo_config,
        train_dataset=dataset,
    )
    print("Starting GRPO training...")
    trainer.train()

    rewards = _extract_rewards(trainer)
    _plot_rewards(rewards)
# Script entry point: run training only when executed directly, so the
# module can be imported (e.g. for build_prompt) without side effects.
if __name__ == "__main__":
    main()