"""
PPO Training Script for the MultiPref Dataset

Uses a trained reward model for policy optimization, with distributed
training handled via accelerate.
Based on: https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py
"""
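# Example launch (illustrative: assumes this file is saved as ppo_training.py;
# flag names come from ScriptArguments and PPOConfig/TrainingArguments, and
# paths are placeholders):
#
#   accelerate launch ppo_training.py \
#       --output_dir ./results/ppo_model \
#       --reward_model_path ./results/reward_model \
#       --per_device_train_batch_size 1 \
#       --learning_rate 3e-6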
from dataclasses import dataclass, field

import torch
from accelerate import PartialState
from datasets import Dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
from trl import ModelConfig, PPOConfig, PPOTrainer


@dataclass
class ScriptArguments:
| | """Script arguments for PPO training""" |
| | model_name: str = field( |
| | default="meta-llama/Meta-Llama-3-8B-Instruct", |
| | metadata={"help": "Base policy model name"} |
| | ) |
| | reward_model_path: str = field( |
| | default="./results/reward_model", |
| | metadata={"help": "Path to trained reward model"} |
| | ) |
| | dataset_path: str = field( |
| | default="data/multipref_train.hf", |
| | metadata={"help": "Path to training dataset"} |
| | ) |
| | max_prompt_length: int = field( |
| | default=512, |
| | metadata={"help": "Maximum prompt length"} |
| | ) |
| | max_new_tokens: int = field( |
| | default=256, |
| | metadata={"help": "Maximum new tokens to generate"} |
| | ) |
| | temperature: float = field( |
| | default=0.7, |
| | metadata={"help": "Sampling temperature"} |
| | ) |
| | top_p: float = field( |
| | default=0.95, |
| | metadata={"help": "Top-p sampling parameter"} |
| | ) |
| | use_peft: bool = field( |
| | default=False, |
| | metadata={"help": "Whether to use PEFT for training"} |
| | ) |
| |
|
| |
|
| | def prepare_dataset_for_ppo(dataset, tokenizer, max_prompt_length=512): |
| | """ |
| | Prepare the multipref dataset for PPO training |
| | Extract unique prompts for policy to generate responses |
| | """ |
| | |
| | prompts_seen = set() |
| | unique_examples = [] |
| | |
| | for example in dataset: |
| | if example['prompt'] not in prompts_seen: |
| | prompts_seen.add(example['prompt']) |
| | unique_examples.append(example) |
| | |
| | print(f"Original dataset size: {len(dataset)}") |
| | print(f"Unique prompts: {len(unique_examples)}") |
| | |
| | def format_prompt(example): |
| | |
| | prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{example['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" |
| | return {"query": prompt} |
| | |
| | |
| | unique_dataset = Dataset.from_list(unique_examples) |
| | formatted_dataset = unique_dataset.map(format_prompt) |
| | |
| | |
| | def tokenize(example): |
| | tokens = tokenizer( |
| | example["query"], |
| | truncation=True, |
| | max_length=max_prompt_length, |
| | padding=False, |
| | return_tensors=None, |
| | ) |
| | example["input_ids"] = tokens["input_ids"] |
| | example["query"] = tokenizer.decode(tokens["input_ids"]) |
| | return example |
| | |
| | tokenized_dataset = formatted_dataset.map(tokenize, batched=False) |
| | tokenized_dataset.set_format(type="torch") |
| | |
| | return tokenized_dataset |
| |
|
| |
|
| | def main(): |
| | parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) |
| | script_args, ppo_config, model_config = parser.parse_args_into_dataclasses() |
| | |
| | |
    # Left padding so batched generation continues right after each prompt.
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path or script_args.model_name,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )

    # Decoder-only checkpoints often ship without a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Build the dataset on the local main process first so other ranks reuse
    # the cached result instead of tokenizing again.
    with PartialState().local_main_process_first():
        print("Loading and preparing dataset...")
        dataset = load_from_disk(script_args.dataset_path)
        train_dataset = prepare_dataset_for_ppo(
            dataset,
            tokenizer,
            max_prompt_length=script_args.max_prompt_length,
        )

    print(f"Prepared dataset size: {len(train_dataset)}")
    if len(train_dataset) > 0:
        sample_query = tokenizer.decode(train_dataset[0]["input_ids"])
        print("Sample query:", sample_query[:200] + "...")
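    # ModelConfig carries torch_dtype as a string ("auto", "bfloat16", ...);
    # resolving it to a real torch.dtype before loading mirrors the upstream
    # TRL example this script is based on.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )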
    # Policy model that PPO optimizes.
    policy = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path or script_args.model_name,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
    )

    # Trained reward model, a scalar-output sequence classifier.
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        script_args.reward_model_path,
        num_labels=1,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
    )
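    # PPO also needs a value model for advantage estimation; following the TRL
    # example, it is initialized from the trained reward model checkpoint
    # (a common choice, not a requirement).
    value_model = AutoModelForSequenceClassification.from_pretrained(
        script_args.reward_model_path,
        num_labels=1,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
    )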
    # The PPOTrainer tokenizes prompts and scores completions with a single
    # processing class, so the reward model must share the policy tokenizer's
    # vocabulary; no separate reward tokenizer is needed.
    # Optional LoRA adapters for parameter-efficient training.
    peft_config = None
    if script_args.use_peft:
        from peft import LoraConfig, TaskType

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        )
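    # Without PEFT, PPO needs an explicit frozen reference policy for the KL
    # penalty; with PEFT enabled, TRL derives the reference from the base
    # model with adapters disabled, so no separate copy is loaded. This
    # mirrors the upstream TRL example.
    ref_policy = None
    if peft_config is None:
        ref_policy = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path or script_args.model_name,
            trust_remote_code=model_config.trust_remote_code,
            torch_dtype=torch_dtype,
        )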
    # The trainer generates rollouts internally, so sampling is configured
    # through PPOConfig fields rather than per-call generation kwargs; note
    # that top_p is not consumed by the trainer's generation config.
    ppo_config.response_length = script_args.max_new_tokens
    ppo_config.temperature = script_args.temperature

    # Keyword names follow recent TRL releases (>= 0.12); older versions used
    # config=/tokenizer=/policy=/ref_policy= instead.
    ppo_trainer = PPOTrainer(
        args=ppo_config,
        processing_class=tokenizer,
        model=policy,
        ref_model=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )

    ppo_trainer.train()
    # save_model must run on every rank so sharded weights (e.g. DeepSpeed
    # ZeRO-3) can be gathered; Trainer only writes from the main process.
    ppo_trainer.save_model(ppo_config.output_dir)
    if ppo_trainer.accelerator.is_local_main_process:
        print(f"Saved final model to {ppo_config.output_dir}")
        tokenizer.save_pretrained(ppo_config.output_dir)
        print("PPO training completed!")


if __name__ == "__main__":
    main()