# rm_code / sft.py
# Author: hahayang012
# Commit d8a76be (verified): "Upload folder using huggingface_hub"
from unsloth import FastLanguageModel
import torch
from unsloth.chat_templates import get_chat_template
from datasets import load_dataset,concatenate_datasets
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
import wandb
from unsloth.chat_templates import standardize_sharegpt
from datasets import Dataset
# ---- Run configuration -------------------------------------------------------
# Context length shared by model loading and the SFT trainer. Unsloth handles
# RoPE scaling internally, so other values are accepted too.
max_seq_length = 2048
# None lets Unsloth auto-detect the dtype (Float16 for Tesla T4/V100,
# Bfloat16 for Ampere+).
dtype = None
# Load the base weights 4-bit quantized to reduce memory; may be set to False.
load_in_4bit = True
# Output directory path for this run.
outputs = "/home/Mistral-Small-3.1-24B-Base-2503/outputs"

# Start the Weights & Biases run that the trainer reports into.
wandb.init(project="Mistral-Small-3.1-24B-Base-2503-SFT", name="run3")
# Load the base checkpoint and its tokenizer through Unsloth, using the length,
# dtype and quantization settings chosen above. For gated checkpoints a Hugging
# Face token can additionally be passed as `token=...`.
_base_model_id = "mistralai/Mistral-Small-3.1-24B-Base-2503"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=_base_model_id,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Wrap the model with LoRA adapters (rank 16, alpha 16) on the seven projection
# modules listed below. Per Unsloth's own notes, dropout 0 and bias "none" are
# its optimized settings, and "unsloth" gradient checkpointing uses less VRAM
# than plain True. rsLoRA and LoftQ are left disabled.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
# Switch the tokenizer to the ChatML conversation template. map_eos_token=True
# maps <|im_end|> onto the tokenizer's EOS token (</s>, per Unsloth's note).
tokenizer = get_chat_template(tokenizer, chat_template="chatml", map_eos_token=True)
def remove_unrelated_columns(dataset):
    """Return *dataset* reduced to just its ``conversations`` column."""
    wanted = ["conversations"]
    return dataset.select_columns(wanted)
def clean_shareGPT_remove_weight(dataset):
    """Rebuild a ShareGPT-style dataset keeping only ``from``/``value`` per turn.

    Every row's ``conversations`` list is copied with each turn reduced to its
    ``"from"`` and ``"value"`` keys. Any other per-turn key is dropped; per the
    function name, presumably SlimOrca's ``weight`` field. Other columns are
    discarded as well.

    Iterates the input in pure Python and returns a new ``Dataset`` built with
    ``Dataset.from_list``.
    """
    rows = [
        {
            "conversations": [
                {"from": turn["from"], "value": turn["value"]}
                for turn in row["conversations"]
            ]
        }
        for row in dataset
    ]
    return Dataset.from_list(rows)
def formatting_prompts_func(examples):
    """Render a batch of conversations into training strings.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["conversations"]``
    is a list of conversations, each rendered with the module-level ``tokenizer``'s
    chat template into plain text. The text is not tokenized, and no trailing
    generation prompt is added, because the full conversation is the training
    target.

    Returns a dict with a ``"text"`` column of rendered strings, one per input
    conversation.
    """
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]
    return {"text": texts}
def reorder_conversations(example):
    """Rebuild each conversation turn as ``{"role": ..., "content": ...}``.

    The key order is fixed, with ``role`` first and ``content`` second. Any
    other keys on a turn are dropped. Only the ``conversations`` column is
    returned; a turn missing either key raises ``KeyError``.
    """
    return {
        "conversations": [
            {"role": turn["role"], "content": turn["content"]}
            for turn in example["conversations"]
        ]
    }
# ---- Dataset 1: Sonnet 3.5 character-card roleplay (ShareGPT format) ---------
ds1 = load_dataset("Gryphe/Sonnet3.5-Charcard-Roleplay", split = "train")
# Normalize ShareGPT turns; reorder_conversations below reads role/content keys.
ds1 = standardize_sharegpt(ds1)
ds1 = ds1.map(reorder_conversations)
ds1 = ds1.map(formatting_prompts_func, batched = True,)

# ---- Dataset 2: anime-character roleplay (conversation stored in "messages") --
ds2 = load_dataset("zerofata/Roleplay-Anime-Characters", split = "train")
ds2 = ds2.rename_column("messages", "conversations")
ds2 = remove_unrelated_columns(ds2)
ds2 = ds2.map(reorder_conversations)
ds2 = ds2.map(formatting_prompts_func, batched = True,)

# ---- Dataset 3: first 20,000 rows of SlimOrca --------------------------------
# Truncate *before* cleaning. Every step below is a row-wise, order-preserving
# transform, so this yields exactly the same rows as truncating at the end.
# The difference is cost: the pure-Python clean_shareGPT_remove_weight loop and
# the subsequent maps now only touch the rows that are actually kept, instead of
# the whole dataset.
# min() keeps select() from failing if the split has fewer rows than the limit.
slimorca_limit = 20000
ds3 = load_dataset("Open-Orca/SlimOrca", split="train")
ds3 = ds3.select(range(min(slimorca_limit, len(ds3))))
ds3 = remove_unrelated_columns(ds3)
ds3 = clean_shareGPT_remove_weight(ds3)
ds3 = standardize_sharegpt(ds3)
ds3 = ds3.map(reorder_conversations)
ds3 = ds3.map(formatting_prompts_func, batched = True,)
# Keep only the rendered "text" column (dropping "conversations" and anything
# else), so that the three datasets share one schema and can be concatenated.
ds1, ds2, ds3 = [
    part.remove_columns([name for name in part.column_names if name != "text"])
    for part in (ds1, ds2, ds3)
]
ds = concatenate_datasets([ds1, ds2, ds3])
# ---- Supervised fine-tuning ---------------------------------------------------
# Trains on the concatenated "text" column without packing. Per device, the
# effective batch is 4 sequences x 4 gradient-accumulation steps = 16 sequences
# per optimizer step. No max_steps is set, so the run length comes from
# TrainingArguments' epoch setting.
use_bf16 = is_bfloat16_supported()
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = ds,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.01, # 1% of total steps (previous comment wrongly said 3%)
        learning_rate = 4e-5,
        # Prefer bf16 where the hardware supports it; otherwise fall back to fp16.
        fp16 = not use_bf16,
        bf16 = use_bf16,
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        # Previously the relative path "outputs", which ignored the `outputs`
        # directory configured at the top of the script.
        output_dir = outputs,
        report_to = "wandb",
        run_name = "run3",
    ),
)
trainer_stats = trainer.train()
import os

# Merge the LoRA adapters into the base weights and save a 16-bit copy locally.
model.save_pretrained_merged("/home/Mistral-Small-3.1-24B-Base-2503/model_1", tokenizer, save_method = "merged_16bit",)

# Upload the same merged 16-bit model to the Hub. The token used to be a
# placeholder string ("not written yet"), which would make the upload fail.
# It is now read from the HF_TOKEN environment variable. When that variable is
# unset, this passes None; presumably huggingface_hub then falls back to
# credentials cached by `huggingface-cli login`. Confirm for the installed
# Unsloth version.
hf_token = os.environ.get("HF_TOKEN")
model.push_to_hub_merged("hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT-1", tokenizer, save_method = "merged_16bit", token = hf_token)