File size: 5,066 Bytes
a1190da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
"""
Minimal PPO test - tests if TRL PPO works with custom reward model.
Uses base GPT-2 (no LoRA) for simplicity.
"""

import os
# NOTE(review): presumably silences TRL's experimental-API warning; set
# before trl is imported below so it takes effect — confirm against TRL docs.
os.environ['TRL_EXPERIMENTAL_SILENCE'] = '1'

import sys
from pathlib import Path

# Make the project root (two directories above this file) and its classes/
# subdirectory importable when the script is run directly.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

import torch
import torch.nn as nn
import numpy as np

# Banner, then the first three sequential smoke tests of the PPO pipeline.
SEP = "=" * 60
print(SEP)
print("MINIMAL PPO TEST")
print(SEP)

# Test 1: TRL imports
print("\n[1] Testing TRL imports...")
import trl
print("    TRL version:", trl.__version__)

from trl.experimental.ppo import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
print("    [OK] PPO modules imported")

# Test 2: Load base tokenizer and model
print("\n[2] Loading base GPT-2...")
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
print("    [OK] Tokenizer loaded")

base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
print("    [OK] Base model loaded")

# Test 3: Create custom reward model
print("\n[3] Creating custom reward model...")

class SequenceClassifierOutput:
    """Minimal stand-in for transformers' output type: just wraps ``logits``."""

    def __init__(self, logits):
        self.logits = logits

class SimpleRewardModel(nn.Module):
    """Reward model that returns random scores — for pipeline testing only.

    Mimics the minimal surface PPOTrainer expects from a reward model: a
    ``config`` object exposing ``pad_token_id`` and a ``forward`` whose
    result has a ``.logits`` attribute of shape ``(batch, 1)``.
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # Fake just the one config attribute the trainer reads.
        self.config = type('Config', (), {'pad_token_id': tokenizer.pad_token_id})()
        # Frozen dummy parameter so .to(device) / parameter iteration have
        # at least one tensor to work with.
        self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return random per-sequence rewards, shape ``(batch, 1)``.

        Fix: allocate the rewards on the same device as ``input_ids`` so the
        model also works with GPU inputs (the original always used CPU,
        which would raise a device-mismatch error downstream).
        """
        batch_size = input_ids.shape[0]
        rewards = torch.rand(batch_size, 1, device=input_ids.device)
        return SequenceClassifierOutput(logits=rewards)

# Instantiate the reward model on CPU (keeps this smoke test GPU-free).
device = torch.device("cpu")  # Use CPU for testing
reward_model = SimpleRewardModel(tokenizer).to(device)
print("    [OK] Custom reward model created")

# Sanity-check one forward pass on a tiny tokenized input.
encoded = tokenizer("test input", return_tensors="pt")
output = reward_model(encoded["input_ids"])
print(f"    [OK] Forward pass works, logits shape: {output.logits.shape}")

# Test 4: PPOConfig
print("\n[4] Creating PPOConfig...")
try:
    ppo_config = PPOConfig(
        output_dir="./output/ppo_test",
        learning_rate=1e-5,
        per_device_train_batch_size=2,
        total_episodes=4,       # tiny run — just prove the loop executes
        num_ppo_epochs=1,
        response_length=20,
        # Fix: TrainingArguments interprets None as "use default reporters";
        # the documented way to disable logging integrations is the string
        # "none".
        report_to="none",
        use_cpu=True,  # Required for CPU-only systems
        bf16=False,
    )
    print("    [OK] PPOConfig created")
except Exception as e:
    print(f"    [FAIL] PPOConfig: {e}")
    import traceback
    traceback.print_exc()

# Test 5: Create models with value head
print("\n[5] Creating models with value heads...")
try:
    from transformers import GenerationConfig

    # PPO needs three networks: the trainable policy, a frozen reference,
    # and a value estimator. Each is loaded fresh from the checkpoint name
    # (loading from the string, not from an existing model object).
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    value_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

    # PPOTrainer requires a generation_config attribute on each model.
    gen_config = GenerationConfig.from_pretrained("gpt2")
    for model in (policy_model, ref_model, value_model):
        model.generation_config = gen_config

    # The value-head wrapper may lack base_model_prefix; patch in GPT-2's.
    if not hasattr(policy_model, 'base_model_prefix'):
        for model in (policy_model, ref_model, value_model):
            model.base_model_prefix = 'transformer'

    print("    [OK] Models with value heads created")
except Exception as e:
    print(f"    [FAIL] {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# Test 6: Create dataset
print("\n[6] Creating training dataset...")
from datasets import Dataset

# Four identical prompts are enough to drive one tiny PPO batch.
prompt = '{"vars": ["x_1"], "ops": ["+", "sin"], "cons": null, "expr": "'
queries = [prompt for _ in range(4)]
train_dataset = Dataset.from_dict({"query": queries})
print(f"    [OK] Dataset with {len(train_dataset)} samples")

# Test 7: Create PPOTrainer
print("\n[7] Creating PPOTrainer...")
try:
    ppo_trainer = PPOTrainer(
        args=ppo_config,
        processing_class=tokenizer,
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
    )
except Exception as e:
    print(f"    [FAIL] PPOTrainer creation failed: {e}")
    import traceback
    traceback.print_exc()
else:
    print("    [OK] PPOTrainer created!")

    # Test 8: Try a training step
    print("\n[8] Testing training step...")
    try:
        ppo_trainer.train()
    except Exception as e:
        print(f"    [FAIL] Training failed: {e}")
        import traceback
        traceback.print_exc()
    else:
        print("    [OK] Training completed!")

print("\n" + "=" * 60)
print("TEST COMPLETE")
print("=" * 60)