"""src/sft.py: supervised fine-tuning of a GPT-2 checkpoint with LoRA on UltraChat, using TRL's SFTTrainer."""

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
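
# NOTE: written against the pre-0.9 TRL API, where tokenizer, max_seq_length,
# packing, and dataset_text_field are passed to SFTTrainer directly; newer
# releases move these options onto trl.SFTConfig.
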
def to_single_turn_text(example):
    """Flatten a multi-turn UltraChat record into one single-turn training string."""
    msgs = example["messages"]
    # Take the first user and first assistant message; next() raises StopIteration
    # if either role is missing (every ultrachat_200k record has both).
    user = next(m["content"] for m in msgs if m["role"] == "user")
    assistant = next(m["content"] for m in msgs if m["role"] == "assistant")
    return {"text": f"User: {user}\nAssistant: {assistant}"}
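
# Illustrative input/output of the mapping above (made-up record):
#   {"messages": [{"role": "user", "content": "Hi"},
#                 {"role": "assistant", "content": "Hello!"}]}
#   -> {"text": "User: Hi\nAssistant: Hello!"}
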
def main(
base_model_dir="hf_pretrained",
out_dir="hf_sft_lora",
max_seq_len=512,
):
    # UltraChat's SFT split; each record is a list of chat messages.
    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    ds = ds.map(to_single_turn_text, remove_columns=ds.column_names)
    tok = AutoTokenizer.from_pretrained(base_model_dir, use_fast=True)
    if tok.pad_token is None:
        # GPT-2 ships without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=torch.bfloat16,  # bf16 weights to match bf16 training below
        device_map="auto",
    )
    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # GPT-2's Conv1D layers: the fused QKV projection (c_attn), the
        # attention/MLP output projections (c_proj), and the MLP up-projection (c_fc).
        target_modules=["c_attn", "c_proj", "c_fc"],
    )
model = get_peft_model(model, lora)
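    model.print_trainable_parameters()  # optional: report LoRA vs. frozen parameter counts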
    # Completion-only loss: labels before and including the response template are
    # masked, so only assistant tokens contribute to the loss. The leading newline
    # keeps the template from matching an "Assistant:" that happens to appear
    # inside a user message.
    collator = DataCollatorForCompletionOnlyLM(
        response_template="\nAssistant:",
        tokenizer=tok,
    )
    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,  # effective batch of 64 per device per optimizer step
        learning_rate=2e-4,
        num_train_epochs=1,
        bf16=True,
        logging_steps=25,
        save_steps=500,
        save_total_limit=3,
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=ds,
        args=args,
        max_seq_length=max_seq_len,
        data_collator=collator,
        packing=False,  # packing is incompatible with the completion-only collator
        dataset_text_field="text",
    )
    trainer.train()
    # With a PEFT model, save_model() writes only the LoRA adapter weights, not a
    # full checkpoint; save the tokenizer alongside so the output dir reloads cleanly.
    trainer.save_model(out_dir)
    tok.save_pretrained(out_dir)
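

# Minimal inference sketch (a hypothetical helper, not called by main): it assumes
# peft >= 0.5, whose AutoPeftModelForCausalLM reads the base-model path from the
# adapter config saved above and attaches the adapter automatically.
def demo_generate(adapter_dir="hf_sft_lora", prompt="User: What is LoRA?\nAssistant:"):
    from peft import AutoPeftModelForCausalLM

    tok = AutoTokenizer.from_pretrained(adapter_dir)
    model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Prompt in the same "User: ...\nAssistant:" format used for training.
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=128, pad_token_id=tok.pad_token_id)
    print(tok.decode(out[0], skip_special_tokens=True))
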
if __name__ == "__main__":
main()