File size: 5,066 Bytes
a1190da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
#!/usr/bin/env python3
"""
Minimal PPO test - tests if TRL PPO works with custom reward model.
Uses base GPT-2 (no LoRA) for simplicity.
"""
import os
# Must be set BEFORE `import trl` (done further down) to silence the
# experimental-API warning TRL emits on import.
os.environ['TRL_EXPERIMENTAL_SILENCE'] = '1'
import sys
from pathlib import Path
# Make the project root and its `classes/` directory importable when this
# file is run directly as a script (it lives one level below the root).
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
import torch
import torch.nn as nn
import numpy as np
print("=" * 60)
print("MINIMAL PPO TEST")
print("=" * 60)
# Test 1: TRL imports
print("\n[1] Testing TRL imports...")
import trl
print(f" TRL version: {trl.__version__}")
# PPO lives under `trl.experimental` in this TRL version.
from trl.experimental.ppo import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
print(" [OK] PPO modules imported")
# Test 2: Load base tokenizer and model
print("\n[2] Loading base GPT-2...")
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse EOS so padded batching works.
tokenizer.pad_token = tokenizer.eos_token
print(" [OK] Tokenizer loaded")
# NOTE(review): `base_model` is loaded but never used below — the value-head
# models in test 5 reload "gpt2" themselves. Confirm whether this load is
# only meant as a smoke test.
base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
print(" [OK] Base model loaded")
# Test 3: Create custom reward model
print("\n[3] Creating custom reward model...")
class SequenceClassifierOutput:
    """Minimal stand-in for the HF sequence-classifier output type.

    TRL's reward path only reads ``output.logits`` from the reward model,
    so this tiny container is all the custom reward model needs to return.
    """

    def __init__(self, logits):
        # Reward scores tensor; shaped (batch, 1) in this script.
        self.logits = logits
class SimpleRewardModel(nn.Module):
    """Simple reward model that returns random scores for testing.

    Mimics the minimal interface PPOTrainer needs from a reward model:
    a ``.config`` with ``pad_token_id`` and a ``forward`` that returns an
    object exposing ``.logits`` of shape (batch, 1).
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # PPOTrainer reads `model.config.pad_token_id`; fake a minimal config.
        self.config = type('Config', (), {'pad_token_id': tokenizer.pad_token_id})()
        # Frozen dummy parameter so .to(device)/.parameters() behave like a
        # real module even though no learning happens here.
        self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return random per-sequence rewards, shape (batch, 1).

        `attention_mask` and extra kwargs are accepted (and ignored) so the
        call signature matches what PPOTrainer passes.
        """
        batch_size = input_ids.shape[0]
        # Fix: allocate rewards on the same device as the inputs; the original
        # torch.rand(batch_size, 1) always allocated on CPU and would cause a
        # device mismatch if the model/inputs were moved to GPU.
        rewards = torch.rand(batch_size, 1, device=input_ids.device)
        return SequenceClassifierOutput(logits=rewards)
device = torch.device("cpu") # Use CPU for testing
reward_model = SimpleRewardModel(tokenizer).to(device)
print(" [OK] Custom reward model created")
# Test forward pass
# Smoke-test: one tokenized string through the reward model; expect
# logits of shape (1, 1).
test_ids = tokenizer("test input", return_tensors="pt")["input_ids"]
output = reward_model(test_ids)
print(f" [OK] Forward pass works, logits shape: {output.logits.shape}")
# Test 4: PPOConfig
print("\n[4] Creating PPOConfig...")
try:
    ppo_config = PPOConfig(
        output_dir="./output/ppo_test",
        learning_rate=1e-5,
        per_device_train_batch_size=2,
        total_episodes=4,  # tiny run: just enough to exercise the PPO loop
        num_ppo_epochs=1,
        response_length=20,
        # NOTE(review): HF TrainingArguments may resolve report_to=None to
        # "all" integrations in some versions; if no logging is intended,
        # "none" (string) may be the safer value — verify.
        report_to=None,
        use_cpu=True, # Required for CPU-only systems
        bf16=False,
    )
    print(" [OK] PPOConfig created")
except Exception as e:
    print(f" [FAIL] PPOConfig: {e}")
    import traceback
    traceback.print_exc()
    # Fix: abort here, consistent with the test-5 handler. Previously the
    # script continued and crashed later with a NameError on `ppo_config`
    # at PPOTrainer creation, masking the real failure.
    sys.exit(1)
# Test 5: Create models with value head
print("\n[5] Creating models with value heads...")
try:
    from transformers import GenerationConfig
    # Load models directly from pretrained string (not from model object)
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    value_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    # Add generation_config (required by PPOTrainer)
    # NOTE(review): the single GenerationConfig instance is shared by all
    # three models — a mutation through one would affect the others.
    # Confirm this aliasing is intended.
    gen_config = GenerationConfig.from_pretrained("gpt2")
    policy_model.generation_config = gen_config
    ref_model.generation_config = gen_config
    value_model.generation_config = gen_config
    # Add base_model_prefix if missing
    # (PPOTrainer uses it to locate the underlying transformer; the check is
    # only done on policy_model and assumed to hold for the other two.)
    if not hasattr(policy_model, 'base_model_prefix'):
        policy_model.base_model_prefix = 'transformer'
        ref_model.base_model_prefix = 'transformer'
        value_model.base_model_prefix = 'transformer'
    print(" [OK] Models with value heads created")
except Exception as e:
    print(f" [FAIL] {e}")
    import traceback
    traceback.print_exc()
    # All later tests depend on these models, so abort on failure.
    sys.exit(1)
# Test 6: Create dataset
print("\n[6] Creating training dataset...")
from datasets import Dataset
# Prompt is a JSON-ish prefix the policy should complete into an expression.
prompt = '{"vars": ["x_1"], "ops": ["+", "sin"], "cons": null, "expr": "'
# Four copies of the same prompt — matches total_episodes=4 in the config.
# NOTE(review): some TRL versions expect a pre-tokenized "input_ids" column
# rather than raw "query" text; confirm against the PPOTrainer version used.
train_dataset = Dataset.from_dict({"query": [prompt] * 4})
print(f" [OK] Dataset with {len(train_dataset)} samples")
# Test 7: Create PPOTrainer
print("\n[7] Creating PPOTrainer...")
try:
    ppo_trainer = PPOTrainer(
        args=ppo_config,
        processing_class=tokenizer, # replaces the old `tokenizer=` kwarg
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
    )
    print(" [OK] PPOTrainer created!")
    # Test 8: Try a training step
    # Nested try: a training failure is reported separately from a
    # trainer-construction failure.
    print("\n[8] Testing training step...")
    try:
        ppo_trainer.train()
        print(" [OK] Training completed!")
    except Exception as e:
        print(f" [FAIL] Training failed: {e}")
        import traceback
        traceback.print_exc()
except Exception as e:
    print(f" [FAIL] PPOTrainer creation failed: {e}")
    import traceback
    traceback.print_exc()
print("\n" + "=" * 60)
print("TEST COMPLETE")
print("=" * 60)
|