#!/usr/bin/env python3
"""
PPO Training Script for MultiPref Dataset
Uses trained reward model for policy optimization with proper distributed training
Based on: https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py
"""
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoModelForSequenceClassification,
HfArgumentParser,
)
from datasets import load_from_disk, Dataset
from trl import PPOTrainer, PPOConfig, ModelConfig
from accelerate import PartialState
@dataclass
class ScriptArguments:
"""Script arguments for PPO training"""
model_name: str = field(
default="meta-llama/Meta-Llama-3-8B-Instruct",
metadata={"help": "Base policy model name"}
)
    # NOTE: the reward model path is supplied via PPOConfig's own --reward_model_path
    # argument (e.g. --reward_model_path ./results/reward_model); declaring the field
    # here as well would register a duplicate argparse option in HfArgumentParser.
dataset_path: str = field(
default="data/multipref_train.hf",
metadata={"help": "Path to training dataset"}
)
max_prompt_length: int = field(
default=512,
metadata={"help": "Maximum prompt length"}
)
max_new_tokens: int = field(
default=256,
metadata={"help": "Maximum new tokens to generate"}
)
    # NOTE: the sampling temperature is supplied via PPOConfig's own --temperature
    # argument; declaring it here as well would register a duplicate argparse option.
top_p: float = field(
default=0.95,
metadata={"help": "Top-p sampling parameter"}
)
    # NOTE: --use_peft and the LoRA hyper-parameters are supplied via trl.ModelConfig;
    # declaring use_peft here as well would register a duplicate argparse option.
def prepare_dataset_for_ppo(dataset, tokenizer, max_prompt_length=512):
"""
Prepare the multipref dataset for PPO training
Extract unique prompts for policy to generate responses
"""
# Extract unique prompts to avoid duplicates
prompts_seen = set()
unique_examples = []
for example in dataset:
if example['prompt'] not in prompts_seen:
prompts_seen.add(example['prompt'])
unique_examples.append(example)
print(f"Original dataset size: {len(dataset)}")
print(f"Unique prompts: {len(unique_examples)}")
def format_prompt(example):
# Format prompt for Llama-3 instruction following
prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{example['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
return {"query": prompt}
# Convert to dataset and format
unique_dataset = Dataset.from_list(unique_examples)
formatted_dataset = unique_dataset.map(format_prompt)
    # Tokenize queries
    def tokenize(example):
        tokens = tokenizer(
            example["query"],
            truncation=True,
            max_length=max_prompt_length,
            padding=False,
            add_special_tokens=False,  # the template above already contains <|begin_of_text|>
            return_tensors=None,
        )
        example["input_ids"] = tokens["input_ids"]
        example["query"] = tokenizer.decode(tokens["input_ids"])
        return example
    tokenized_dataset = formatted_dataset.map(tokenize, batched=False)
    # Keep only the tokenized prompts: leftover string columns (e.g. the raw MultiPref
    # fields) would break the trainer's padding collator.
    tokenized_dataset = tokenized_dataset.remove_columns(
        [c for c in tokenized_dataset.column_names if c != "input_ids"]
    )
    tokenized_dataset.set_format(type="torch")
    return tokenized_dataset
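# For reference, a formatted query produced by prepare_dataset_for_ppo looks like the
# following (Llama-3 chat markup; the prompt text itself is illustrative):
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
#   <prompt text><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n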
def main():
parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, ppo_config, model_config = parser.parse_args_into_dataclasses()
    # Load tokenizer (left padding, as in the upstream TRL PPO example, so that
    # generated rollouts are not corrupted by right-side padding of the prompts)
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path or script_args.model_name,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )
    # Add pad token for Llama models
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
# Load and prepare dataset with distributed-aware processing
with PartialState().local_main_process_first():
print("Loading and preparing dataset...")
dataset = load_from_disk(script_args.dataset_path)
train_dataset = prepare_dataset_for_ppo(
dataset,
tokenizer,
max_prompt_length=script_args.max_prompt_length
)
print(f"Prepared dataset size: {len(train_dataset)}")
if len(train_dataset) > 0:
print("Sample query:", train_dataset[0]['query'][:200] + "...")
# Load policy model
policy = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path or script_args.model_name,
trust_remote_code=model_config.trust_remote_code,
torch_dtype=model_config.torch_dtype,
)
    # Load reward model (the path comes from PPOConfig's --reward_model_path)
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        ppo_config.reward_model_path,
        num_labels=1,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=model_config.torch_dtype,
    )
    # The reworked PPOTrainer also needs a value model; as in the upstream TRL example,
    # initialize it from the reward model checkpoint.
    value_model = AutoModelForSequenceClassification.from_pretrained(
        ppo_config.reward_model_path,
        num_labels=1,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=model_config.torch_dtype,
    )
    # Load reward tokenizer (should match the policy tokenizer)
    reward_tokenizer = AutoTokenizer.from_pretrained(
        ppo_config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
    )
if reward_tokenizer.pad_token is None:
reward_tokenizer.pad_token = reward_tokenizer.eos_token
reward_tokenizer.pad_token_id = reward_tokenizer.eos_token_id
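    # Optional sanity check (assumption: the reward model was trained with the same
    # vocabulary as the policy; reward scores are not meaningful if the tokenizers differ).
    if reward_tokenizer.get_vocab() != tokenizer.get_vocab():
        print("WARNING: reward-model tokenizer vocabulary differs from the policy tokenizer.")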
    # PEFT configuration if requested (--use_peft comes from trl.ModelConfig)
    peft_config = None
    if model_config.use_peft:
from peft import LoraConfig, TaskType
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=16,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
    # Create PPO trainer
    ppo_trainer = PPOTrainer(
        config=ppo_config,
        processing_class=tokenizer,
        policy=policy,
        ref_policy=None,  # created automatically by TRL (skipped when PEFT is used)
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )
    # Generation settings, kept for reference. NOTE: the reworked PPOTrainer generates
    # rollouts internally from PPOConfig (e.g. --response_length, --temperature), so this
    # dict is not passed to the trainer.
    generation_kwargs = {
        "max_new_tokens": script_args.max_new_tokens,
        "temperature": ppo_config.temperature,
        "top_p": script_args.top_p,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # Start training
    ppo_trainer.train()
    # Save final model. save_model is called on every rank; the underlying Trainer handles
    # the main-process check (and any sharded-state gathering) internally.
    if ppo_trainer.accelerator.is_main_process:
        print(f"Saving final model to {ppo_config.output_dir}")
    ppo_trainer.save_model(ppo_config.output_dir)
    if ppo_trainer.accelerator.is_main_process:
        tokenizer.save_pretrained(ppo_config.output_dir)
        print("PPO training completed!")
if __name__ == "__main__":
main()
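# Example launch (illustrative; adjust paths, GPU count, and hyper-parameters):
#   accelerate launch --num_processes 8 train_ppo_distributed.py \
#       --model_name meta-llama/Meta-Llama-3-8B-Instruct \
#       --reward_model_path ./results/reward_model \
#       --dataset_path data/multipref_train.hf \
#       --output_dir ./results/ppo_model \
#       --learning_rate 1e-6 \
#       --per_device_train_batch_size 1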