#!/usr/bin/env python3
"""
Minimal PPO smoke test: checks whether TRL's experimental PPO stack works
with a custom reward model. Uses base GPT-2 (no LoRA) for simplicity.
"""
import os
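# Silence TRL's warning about importing from its experimental namespace.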
os.environ['TRL_EXPERIMENTAL_SILENCE'] = '1'
import sys
from pathlib import Path
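# Make the project root and its classes/ directory importable when running from scripts/.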
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
import torch
import torch.nn as nn
print("=" * 60)
print("MINIMAL PPO TEST")
print("=" * 60)
# Test 1: TRL imports
print("\n[1] Testing TRL imports...")
import trl
print(f" TRL version: {trl.__version__}")
from trl.experimental.ppo import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
print(" [OK] PPO modules imported")
# Test 2: Load base tokenizer and model
print("\n[2] Loading base GPT-2...")
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
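# GPT-2 ships without a pad token, so reuse EOS for padding.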
tokenizer.pad_token = tokenizer.eos_token
print(" [OK] Tokenizer loaded")
base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
print(" [OK] Base model loaded")
# Test 3: Create custom reward model
print("\n[3] Creating custom reward model...")
class SequenceClassifierOutput:
    """Minimal stand-in for transformers' SequenceClassifierOutput (just .logits)."""
    def __init__(self, logits):
        self.logits = logits

class SimpleRewardModel(nn.Module):
    """Simple reward model that returns random scores for testing."""
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # PPOTrainer looks up pad_token_id on the reward model's config.
        self.config = type('Config', (), {'pad_token_id': tokenizer.pad_token_id})()
        # Dummy parameter so .to(device) and parameter iteration have something to move.
        self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        batch_size = input_ids.shape[0]
        # Return random rewards for testing, on the same device as the inputs
        rewards = torch.rand(batch_size, 1, device=input_ids.device)
        return SequenceClassifierOutput(logits=rewards)
device = torch.device("cpu") # Use CPU for testing
reward_model = SimpleRewardModel(tokenizer).to(device)
print(" [OK] Custom reward model created")
# Test forward pass
test_ids = tokenizer("test input", return_tensors="pt")["input_ids"]
output = reward_model(test_ids)
print(f" [OK] Forward pass works, logits shape: {output.logits.shape}")
# Test 4: PPOConfig
print("\n[4] Creating PPOConfig...")
try:
    ppo_config = PPOConfig(
        output_dir="./output/ppo_test",
        learning_rate=1e-5,
        per_device_train_batch_size=2,
        total_episodes=4,
        num_ppo_epochs=1,
        response_length=20,
        report_to="none",  # the string "none" disables trackers; None falls back to defaults
        use_cpu=True,  # Required for CPU-only systems
        bf16=False,
    )
    print(" [OK] PPOConfig created")
except Exception as e:
    print(f" [FAIL] PPOConfig: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)  # every later step needs ppo_config
# Test 5: Create models with value head
print("\n[5] Creating models with value heads...")
try:
    from transformers import GenerationConfig
    # Load models directly from the pretrained string (not from a model object)
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    value_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    # Attach generation_config (required by PPOTrainer for rollouts)
    gen_config = GenerationConfig.from_pretrained("gpt2")
    policy_model.generation_config = gen_config
    ref_model.generation_config = gen_config
    value_model.generation_config = gen_config
    # Add base_model_prefix if missing (the value-head wrapper may not expose it)
    if not hasattr(policy_model, 'base_model_prefix'):
        policy_model.base_model_prefix = 'transformer'
        ref_model.base_model_prefix = 'transformer'
        value_model.base_model_prefix = 'transformer'
    print(" [OK] Models with value heads created")
except Exception as e:
    print(f" [FAIL] {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
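# Note: three separate GPT-2 copies above are deliberate - PPO optimizes the policy,
# the frozen reference model anchors the KL penalty, and the value model estimates
# per-token returns for advantage computation.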
# Test 6: Create dataset
print("\n[6] Creating training dataset...")
from datasets import Dataset
prompt = '{"vars": ["x_1"], "ops": ["+", "sin"], "cons": null, "expr": "'
train_dataset = Dataset.from_dict({"query": [prompt] * 4})
print(f" [OK] Dataset with {len(train_dataset)} samples")
# Test 7: Create PPOTrainer
print("\n[7] Creating PPOTrainer...")
try:
    ppo_trainer = PPOTrainer(
        args=ppo_config,
        processing_class=tokenizer,
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
    )
    print(" [OK] PPOTrainer created!")

    # Test 8: Try a training step
    print("\n[8] Testing training step...")
    try:
        ppo_trainer.train()
        print(" [OK] Training completed!")
    except Exception as e:
        print(f" [FAIL] Training failed: {e}")
        import traceback
        traceback.print_exc()
except Exception as e:
    print(f" [FAIL] PPOTrainer creation failed: {e}")
    import traceback
    traceback.print_exc()
print("\n" + "=" * 60)
print("TEST COMPLETE")
print("=" * 60)