"""
Minimal PPO test - tests if TRL PPO works with a custom reward model.
Uses base GPT-2 (no LoRA) for simplicity.
"""

import os

# Silence TRL's warning about importing from the trl.experimental namespace.
os.environ['TRL_EXPERIMENTAL_SILENCE'] = '1'

import sys
from pathlib import Path

# Make project modules importable when running this script directly.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

import torch
import torch.nn as nn

print("=" * 60)
print("MINIMAL PPO TEST")
print("=" * 60)


print("\n[1] Testing TRL imports...")
import trl
print(f" TRL version: {trl.__version__}")

# PPOConfig and PPOTrainer live under trl.experimental.ppo in recent TRL
# releases; the value-head wrapper is exported from the top-level package.
from trl.experimental.ppo import PPOConfig, PPOTrainer
from trl import AutoModelForCausalLMWithValueHead
print(" [OK] PPO modules imported")


print("\n[2] Loading base GPT-2...")
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 has no pad token by default; reuse EOS so padding and batching work.
tokenizer.pad_token = tokenizer.eos_token
print(" [OK] Tokenizer loaded")

# Plain load as a smoke test; the PPO models are created separately below.
base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
print(" [OK] Base model loaded")


print("\n[3] Creating custom reward model...")


class SequenceClassifierOutput:
    """Minimal stand-in for transformers' SequenceClassifierOutput (only .logits)."""
    def __init__(self, logits):
        self.logits = logits


class SimpleRewardModel(nn.Module):
    """Simple reward model that returns random scores for testing.

    Mimics just enough of a sequence-classification model for PPOTrainer:
    a .config carrying pad_token_id and a forward() whose output exposes
    .logits. Depending on the TRL version, the trainer may additionally
    expect a score() method and a base_model_prefix attribute here.
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = type('Config', (), {'pad_token_id': tokenizer.pad_token_id})()
        # Dummy parameter so the module has a device and dtype of its own.
        self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        batch_size = input_ids.shape[0]
        # One random score per sequence, placed on the same device as the input.
        rewards = torch.rand(batch_size, 1, device=input_ids.device)
        return SequenceClassifierOutput(logits=rewards)


device = torch.device("cpu")
reward_model = SimpleRewardModel(tokenizer).to(device)
print(" [OK] Custom reward model created")


test_ids = tokenizer("test input", return_tensors="pt")["input_ids"]
output = reward_model(test_ids)
print(f" [OK] Forward pass works, logits shape: {output.logits.shape}")


print("\n[4] Creating PPOConfig...")
try:
    ppo_config = PPOConfig(
        output_dir="./output/ppo_test",
        learning_rate=1e-5,
        per_device_train_batch_size=2,
        total_episodes=4,
        num_ppo_epochs=1,
        response_length=20,
        report_to="none",  # the string "none" disables logging integrations
        use_cpu=True,
        bf16=False,
    )
    print(" [OK] PPOConfig created")
except Exception as e:
    print(f" [FAIL] PPOConfig: {e}")
    import traceback
    traceback.print_exc()


print("\n[5] Creating models with value heads...")
try:
    from transformers import GenerationConfig

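    # PPO needs three models: the policy being optimized, a frozen reference
    # copy for the KL penalty, and a value model for advantage estimation.
    # Three full copies are wasteful but fine for a minimal test.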
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    value_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

    # The value-head wrapper may not expose a generation_config of its own,
    # so attach GPT-2's explicitly for the trainer's generate() calls.
    gen_config = GenerationConfig.from_pretrained("gpt2")
    policy_model.generation_config = gen_config
    ref_model.generation_config = gen_config
    value_model.generation_config = gen_config

    # PPOTrainer uses base_model_prefix to locate the transformer backbone;
    # the wrapper is not a PreTrainedModel, so set the attribute manually.
    if not hasattr(policy_model, 'base_model_prefix'):
        policy_model.base_model_prefix = 'transformer'
        ref_model.base_model_prefix = 'transformer'
        value_model.base_model_prefix = 'transformer'

    print(" [OK] Models with value heads created")
except Exception as e:
    print(f" [FAIL] {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)


print("\n[6] Creating training dataset...")
from datasets import Dataset

prompt = '{"vars": ["x_1"], "ops": ["+", "sin"], "cons": null, "expr": "'
train_dataset = Dataset.from_dict({"query": [prompt] * 4})
print(f" [OK] Dataset with {len(train_dataset)} samples")


print("\n[7] Creating PPOTrainer...")
try:
    ppo_trainer = PPOTrainer(
        args=ppo_config,
        processing_class=tokenizer,
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
    )
    print(" [OK] PPOTrainer created!")

    print("\n[8] Testing training step...")
    try:
        ppo_trainer.train()
        print(" [OK] Training completed!")
    except Exception as e:
        print(f" [FAIL] Training failed: {e}")
        import traceback
        traceback.print_exc()

except Exception as e:
    print(f" [FAIL] PPOTrainer creation failed: {e}")
    import traceback
    traceback.print_exc()


print("\n" + "=" * 60)
print("TEST COMPLETE")
print("=" * 60)