from unsloth import PatchDPOTrainer
from unsloth import FastLanguageModel
import torch
from trl import DPOTrainer, DPOConfig
from datasets import concatenate_datasets, load_dataset

# Patch TRL's DPOTrainer with Unsloth's memory-efficient kernels.
PatchDPOTrainer()

max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None           # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True    # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT",  # Choose ANY! eg mistralai/Mistral-7B-Instruct-v0.2
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...",  # use one if using gated models like meta-llama/Llama-2-7b-hf
)

ds1 = load_dataset("parquet", data_files="/home/dataset/data/ds1.parquet")
ds2 = load_dataset("parquet", data_files="/home/dataset/data/ds2.parquet")
ds3 = load_dataset("parquet", data_files="/home/dataset/data/ds3.parquet")
ds4 = load_dataset("parquet", data_files="/home/dataset/data/ds4.parquet")

def prepare_dpo_dataset(dataset):
    """Rename the raw columns into the prompt/chosen/rejected triples DPOTrainer expects."""
    dataset = dataset.map(lambda x: {
        "prompt": x["chosen_prompt"],
        "chosen": x["chosen"],
        "rejected": x["reject"],
    })
    return dataset.select_columns(["prompt", "chosen", "rejected"])

ds1 = prepare_dpo_dataset(ds1)
ds2 = prepare_dpo_dataset(ds2)
ds3 = prepare_dpo_dataset(ds3)
ds4 = prepare_dpo_dataset(ds4)

# Merge the four prepared datasets into a single training split.
train_dataset = concatenate_datasets([
    ds1["train"], ds2["train"], ds3["train"], ds4["train"],
])

model = FastLanguageModel.get_peft_model(
    model,
    r = 64,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0,  # Currently only supports dropout = 0
    bias = "none",     # Currently only supports bias = "none"
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,   # We support rank stabilized LoRA
    loftq_config = None,  # And LoftQ
)

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,  # with LoRA, TRL uses the frozen base model as the implicit reference
    args = DPOConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 3,
        learning_rate = 5e-6,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "outputs",
        report_to = "none",  # Use this for WandB etc
        beta = 0.1,
        max_length = 1024,
        max_prompt_length = 512,
    ),
    train_dataset = train_dataset,
    # eval_dataset = ...,  # pass a held-out split here if you have one
    processing_class = tokenizer,  # on older TRL releases, pass tokenizer = tokenizer instead
)
dpo_trainer.train()
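
# After training, a common next step is to persist the LoRA adapters and run a
# quick generation check. This is a minimal sketch, assuming Unsloth's standard
# save_pretrained / for_inference API; "dpo_lora" and the prompt are hypothetical.
model.save_pretrained("dpo_lora")       # saves only the LoRA adapter weights
tokenizer.save_pretrained("dpo_lora")

FastLanguageModel.for_inference(model)  # switch Unsloth to its faster inference path
inputs = tokenizer("Write a haiku about preference tuning.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))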
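
# Note: prepare_dpo_dataset assumes each parquet file carries "chosen_prompt",
# "chosen", and "reject" columns. A sketch (kept commented out so the script's
# behavior is unchanged; rows and path are hypothetical) for building a
# compatible file with pandas:
#
#     import pandas as pd
#     toy = pd.DataFrame({
#         "chosen_prompt": ["What is 2 + 2?"],
#         "chosen":        ["2 + 2 equals 4."],
#         "reject":        ["2 + 2 equals 5."],
#     })
#     toy.to_parquet("/home/dataset/data/ds_toy.parquet")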