from unsloth import FastLanguageModel
import torch
from unsloth.chat_templates import get_chat_template
from datasets import load_dataset,concatenate_datasets
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
import wandb
from unsloth.chat_templates import standardize_sharegpt
from datasets import Dataset
# --- Run configuration ---------------------------------------------------
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
# NOTE(review): `outputs` is never used below — TrainingArguments uses the
# relative path "outputs" instead; confirm which location is intended.
outputs="/home/Mistral-Small-3.1-24B-Base-2503/outputs"
# Start a Weights & Biases run so training metrics are logged remotely.
wandb.init(
    project="Mistral-Small-3.1-24B-Base-2503-SFT",
    name="run3",
)
# Load the base model (4-bit quantized) and its tokenizer via Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "mistralai/Mistral-Small-3.1-24B-Base-2503", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Attach LoRA adapters to the attention and MLP projections only; the base
# weights stay frozen (and 4-bit quantized), so only the adapters train.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # LoRA rank. Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16, # scaling factor; alpha == r gives an effective scale of 1.0
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
# Switch the tokenizer to the ChatML template so apply_chat_template below
# renders <|im_start|>/<|im_end|>-delimited conversations.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    map_eos_token = True, # Maps <|im_end|> to </s> instead
)
def remove_unrelated_columns(dataset):
    """Keep only the ``conversations`` column, dropping everything else."""
    wanted = ["conversations"]
    return dataset.select_columns(wanted)
def clean_shareGPT_remove_weight(dataset):
    """Rebuild each ShareGPT record, keeping only ``from``/``value`` per turn.

    Drops any extra per-message fields (e.g. SlimOrca's ``weight``) so every
    row has a uniform schema before chat-template standardization.
    """
    records = [
        {
            "conversations": [
                {"from": turn["from"], "value": turn["value"]}
                for turn in row["conversations"]
            ]
        }
        for row in dataset
    ]
    return Dataset.from_list(records)
def formatting_prompts_func(examples):
    """Batched ``Dataset.map`` function: render conversations to flat text.

    Uses the module-level ``tokenizer``'s chat template (ChatML, set above).
    ``add_generation_prompt=False`` because this is supervised fine-tuning
    data, not an inference prompt.

    Args:
        examples: batch dict with a "conversations" column (list of convos).
    Returns:
        dict with a "text" column of rendered template strings.
    """
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]
    # Fixed: removed an unreachable `pass` statement after the return.
    return {"text": texts}
def reorder_conversations(example):
    """Normalize each message to a dict holding exactly ``role`` then ``content``.

    Rebuilding the dicts fixes the key order and drops any other fields.
    """
    ordered = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in example["conversations"]
    ]
    return {"conversations": ordered}
# --- Dataset preparation -------------------------------------------------
# ds1: roleplay data already in ShareGPT form; standardize to role/content,
# then render each conversation to a "text" column.
ds1 = load_dataset("Gryphe/Sonnet3.5-Charcard-Roleplay", split = "train")
ds1 = standardize_sharegpt(ds1)
ds1 = ds1.map(reorder_conversations)
ds1 = ds1.map(formatting_prompts_func, batched = True,)
# ds2: already uses role/content messages; rename the column and drop the rest.
ds2 = load_dataset("zerofata/Roleplay-Anime-Characters", split = "train")
ds2 = ds2.rename_column("messages", "conversations")
ds2 = remove_unrelated_columns(ds2)
ds2 = ds2.map(reorder_conversations)
ds2 = ds2.map(formatting_prompts_func, batched = True,)
# ds3: SlimOrca is ShareGPT-style; strip extra per-message fields (e.g.
# `weight`), standardize, then subsample the first 20k rows.
ds3 = load_dataset("Open-Orca/SlimOrca", split="train")
ds3 = remove_unrelated_columns(ds3)
ds3 = clean_shareGPT_remove_weight(ds3)
ds3 = standardize_sharegpt(ds3)
ds3 = ds3.map(reorder_conversations)
ds3 = ds3.select(range(20000))
ds3 = ds3.map(formatting_prompts_func, batched = True,)
# Keep only the rendered "text" field, dropping "conversations" and everything
# else, so all three datasets share an identical schema for concatenation.
ds1 = ds1.remove_columns([col for col in ds1.column_names if col != "text"])
ds2 = ds2.remove_columns([col for col in ds2.column_names if col != "text"])
ds3 = ds3.remove_columns([col for col in ds3.column_names if col != "text"])
# Debug helpers (uncomment to inspect schemas / sample rows):
# print(ds1.features)
# print(ds2.features)
# print(ds3.features)
# for i in range(3):
#     print("=" * 60)
#     print(ds1[i]["text"])
#     print(ds2[i]["text"])
#     print(ds3[i]["text"])
ds = concatenate_datasets([ds1, ds2, ds3])
# --- Trainer -------------------------------------------------------------
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = ds,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4, # effective batch size = 16 per device
        warmup_ratio = 0.01, # 1% of total steps
        #warmup_steps = 5,
        #max_steps = 60,
        learning_rate = 4e-5,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs", # NOTE(review): relative path; the absolute `outputs` variable defined above is unused — confirm intent
        report_to="wandb",
        run_name="run3",
    ),
)
# Run training, then save/publish the LoRA-merged model.
trainer_stats = trainer.train()
# Fold the LoRA adapters into the base weights and save locally in fp16.
model.save_pretrained_merged("/home/Mistral-Small-3.1-24B-Base-2503/model_1", tokenizer, save_method = "merged_16bit",)
# Fixed: removed a stray trailing "|" artifact that made this line a syntax error.
# TODO: replace the placeholder token ("还没写" = "not written yet") with a real
# HF write token before running — the push will fail authentication as-is.
model.push_to_hub_merged("hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT-1", tokenizer, save_method = "merged_16bit", token = "还没写")