| import os | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments | |
| from peft import LoraConfig, get_peft_model | |
| from trl import SFTTrainer, DataCollatorForCompletionOnlyLM | |
def to_single_turn_text(example):
    """Flatten one chat-format example into a single-turn training string.

    Keeps only the FIRST user message and the FIRST assistant message from
    ``example["messages"]`` (later turns of a multi-turn conversation are
    intentionally dropped) and renders them as::

        User: <user content>
        Assistant: <assistant content>

    Args:
        example: Mapping with a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts (ultrachat schema).

    Returns:
        ``{"text": <formatted string>}``, suitable for ``Dataset.map``.

    Raises:
        ValueError: If the example has no user or no assistant message.
            (A bare ``next(...)`` previously surfaced this as an opaque
            ``StopIteration`` from inside ``Dataset.map``.)
    """
    msgs = example["messages"]
    # Sentinel-default next() so malformed rows fail with a clear message
    # instead of StopIteration.
    user = next((m["content"] for m in msgs if m["role"] == "user"), None)
    assistant = next((m["content"] for m in msgs if m["role"] == "assistant"), None)
    if user is None or assistant is None:
        raise ValueError(
            "example must contain at least one user and one assistant message"
        )
    return {"text": f"User: {user}\nAssistant: {assistant}"}
def main(
    base_model_dir="hf_pretrained",
    out_dir="hf_sft_lora",
    max_seq_len=512,
):
    """LoRA supervised fine-tuning on ultrachat_200k with completion-only loss.

    Loads a pretrained causal LM from ``base_model_dir``, attaches LoRA
    adapters, and trains on single-turn ``User:/Assistant:`` pairs, masking
    the loss so only the assistant completion is learned. Saves the adapter
    weights and tokenizer to ``out_dir``.

    Args:
        base_model_dir: Path or hub id of the base model/tokenizer.
        out_dir: Output directory for checkpoints, adapter, and tokenizer.
        max_seq_len: Maximum sequence length passed to the SFT trainer.
    """
    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    ds = ds.map(to_single_turn_text, remove_columns=ds.column_names)

    tok = AutoTokenizer.from_pretrained(base_model_dir, use_fast=True)
    if tok.pad_token is None:
        # Causal LMs (e.g. GPT-2) often ship without a pad token; reuse EOS.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # GPT-2-style module names — adjust for other architectures
        # (e.g. q_proj/k_proj/v_proj/o_proj for Llama-family models).
        target_modules=["c_attn", "c_proj", "c_fc"],
    )
    model = get_peft_model(model, lora)

    # BUG FIX: passing response_template="Assistant:" as a plain string is
    # unreliable — many tokenizers encode "Assistant:" differently standalone
    # than mid-text after "\n", so the collator can fail to locate the
    # template and the completion mask is wrong. Per the TRL docs, encode the
    # template WITH its preceding context and pass token ids instead.
    response_template_ids = tok.encode("\nAssistant:", add_special_tokens=False)
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template_ids,
        tokenizer=tok,
    )

    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,  # effective batch = 64 per device
        learning_rate=2e-4,
        num_train_epochs=1,
        bf16=True,
        logging_steps=25,
        save_steps=500,
        save_total_limit=3,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=ds,
        args=args,
        max_seq_length=max_seq_len,
        data_collator=collator,
        # Packing must stay off: completion-only masking needs one example
        # per sequence to find the response template.
        packing=False,
        dataset_text_field="text",
    )
    trainer.train()

    trainer.save_model(out_dir)
    tok.save_pretrained(out_dir)
# Script entry point: train with the default paths/hyperparameters.
if __name__ == "__main__":
    main()