|
|
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import DPOTrainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fine-tune Mistral-7B-Instruct with Direct Preference Optimization (DPO)
# on a local JSONL preference dataset.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Mistral's tokenizer ships without a pad token; the trainer needs one to
# batch variable-length sequences. Reusing EOS is the standard workaround.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Quantize to 4-bit via an explicit BitsAndBytesConfig — the bare
# `load_in_4bit=True` kwarg on from_pretrained is deprecated in recent
# transformers releases. bf16 compute matches the bf16 training flag below
# (the default compute dtype would otherwise be fp32).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map="auto",
)
# NOTE(review): backprop through a purely 4-bit-quantized model generally
# requires PEFT/LoRA adapters (e.g. pass `peft_config=` to DPOTrainer) —
# confirm this setup actually trains, or add a LoRA config.

# Preference data: each JSONL record must carry the "prompt" / "chosen" /
# "rejected" fields that DPOTrainer expects.
dataset = load_dataset("json", data_files="dpo_data.jsonl")["train"]

training_args = TrainingArguments(
    output_dir="./dpo_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size 8 per device
    num_train_epochs=2,
    learning_rate=5e-6,  # small LR is typical for DPO fine-tuning
    bf16=True,
    logging_steps=10,
    save_steps=500,
)

trainer = DPOTrainer(
    model=model,
    # With ref_model=None the trainer builds its own frozen copy of `model`
    # to serve as the reference policy.
    ref_model=None,
    tokenizer=tokenizer,
    train_dataset=dataset,
    beta=0.1,  # strength of the KL penalty in the DPO loss
    args=training_args,
)

trainer.train()

# Persist the final weights and tokenizer explicitly — otherwise only the
# intermediate `save_steps` checkpoints would survive the run.
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)
|
|
|