File size: 5,596 Bytes
fa4458a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from trl import RewardConfig, RewardTrainer, is_xpu_available
tqdm.pandas()
@dataclass
class ScriptArguments:
    """Options for the reward-modeling script; the script parses them with `tyro.cli`.

    Each field is followed by a short string literal describing it.
    """

    model_name: str = "facebook/opt-350m"
    """the model name"""
    dataset_name: str = "Anthropic/hh-rlhf"
    """the dataset name"""
    dataset_text_field: str = "text"
    """the text field of the dataset"""
    eval_split: str = "none"
    """the dataset split to evaluate on; default to 'none' (no evaluation)"""
    load_in_8bit: bool = False
    """load the model in 8 bits precision"""
    load_in_4bit: bool = False
    """load the model in 4 bits precision"""
    # NOTE(review): this defaults to True and is forwarded to `from_pretrained`; per the
    # transformers documentation that allows code shipped in the model repository to run
    # locally. The default is kept for backward compatibility -- confirm it is intended.
    trust_remote_code: bool = True
    """Enable `trust_remote_code`"""
    reward_config: RewardConfig = field(
        default_factory=lambda: RewardConfig(
            output_dir="output",
            per_device_train_batch_size=64,
            num_train_epochs=1,
            gradient_accumulation_steps=16,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            learning_rate=1.41e-5,
            report_to="tensorboard",
            remove_unused_columns=False,
            optim="adamw_torch",
            logging_steps=500,
            evaluation_strategy="no",
            max_length=512,
        )
    )
    """the `RewardConfig` handed to the trainer; the script overwrites its `evaluation_strategy` based on `eval_split`"""
    use_peft: bool = False
    """whether to use peft"""
    peft_config: Optional[LoraConfig] = field(
        default_factory=lambda: LoraConfig(
            r=16,
            lora_alpha=16,
            bias="none",
            task_type="SEQ_CLS",
            # The classification head of transformers decoder-style
            # `*ForSequenceClassification` models (including the default OPT) is the
            # attribute `score`; the previous value "scores" names no such module.
            # NOTE(review): other architectures may name the head differently -- confirm
            # when changing `model_name`.
            modules_to_save=["score"],
        ),
    )
    """the `LoraConfig` used when `use_peft` is set"""
args = tyro.cli(ScriptArguments)

# Evaluate every few steps only when an evaluation split was requested.
if args.eval_split == "none":
    args.reward_config.evaluation_strategy = "no"
else:
    args.reward_config.evaluation_strategy = "steps"
# Step 1: Load the model
# The two quantization flags are mutually exclusive.
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

# Unquantized default: no quantization config and no explicit device placement.
quantization_config = None
device_map = None

if args.load_in_8bit or args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
    # Copy the model to each device: the whole model (module name "") is mapped to the
    # device that matches this process' local index.
    on_xpu = is_xpu_available()
    process_index = Accelerator().local_process_index
    if on_xpu:
        device_map = {"": f"xpu:{process_index}"}
    else:
        device_map = {"": process_index}

# num_labels=1 gives the model a single output per sequence.
model_kwargs = {
    "quantization_config": quantization_config,
    "device_map": device_map,
    "trust_remote_code": args.trust_remote_code,
    "num_labels": 1,
}
model = AutoModelForSequenceClassification.from_pretrained(args.model_name, **model_kwargs)
# Step 2: Load the dataset and pre-process it
# Forward `trust_remote_code` so the tokenizer is loaded under the same setting as the
# model, which comes from the same repository (`args.model_name`).
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
train_dataset = load_dataset(args.dataset_name, split="train")
# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):
    """Tokenize one batch of chosen/rejected pairs.

    Args:
        examples: a batch mapping whose "chosen" and "rejected" entries are parallel
            sequences of inputs for the module-level `tokenizer`.

    Returns:
        A dict with the keys "input_ids_chosen", "attention_mask_chosen",
        "input_ids_rejected" and "attention_mask_rejected"; each value is a list holding
        one tokenizer output per pair, in input order.
    """
    columns = {
        f"{field_name}_{side}": []
        for side in ("chosen", "rejected")
        for field_name in ("input_ids", "attention_mask")
    }
    pairs = zip(examples["chosen"], examples["rejected"])
    for chosen_text, rejected_text in pairs:
        # Chosen is tokenized before rejected, one tokenizer call per text.
        for side, text in (("chosen", chosen_text), ("rejected", rejected_text)):
            encoded = tokenizer(text)
            columns[f"input_ids_{side}"].append(encoded["input_ids"])
            columns[f"attention_mask_{side}"].append(encoded["attention_mask"])
    return columns
# Preprocess the dataset and filter out examples that are longer than
# args.reward_config.max_length
def _tokenize_and_filter(dataset):
    """Tokenize `dataset` with `preprocess_function` and drop over-length pairs.

    A pair is kept only when both its chosen and its rejected token sequences have at
    most `args.reward_config.max_length` entries. The train and evaluation splits share
    this helper so that they always go through the same pipeline.

    Args:
        dataset: a dataset exposing `map` and `filter`, with "chosen"/"rejected" columns.

    Returns:
        The mapped and filtered dataset.
    """
    dataset = dataset.map(
        preprocess_function,
        batched=True,
        num_proc=4,
    )
    return dataset.filter(
        lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
        and len(x["input_ids_rejected"]) <= args.reward_config.max_length
    )


train_dataset = _tokenize_and_filter(train_dataset)

# The evaluation split is optional; "none" disables it.
if args.eval_split == "none":
    eval_dataset = None
else:
    eval_dataset = _tokenize_and_filter(load_dataset(args.dataset_name, split=args.eval_split))
# Step 3: Define the LoraConfig
# The configured adapter settings are only used when PEFT was requested.
peft_config = args.peft_config if args.use_peft else None

# Step 4: Define the Trainer
trainer_kwargs = {
    "model": model,
    "tokenizer": tokenizer,
    "args": args.reward_config,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "peft_config": peft_config,
}
trainer = RewardTrainer(**trainer_kwargs)

trainer.train()
|