File size: 5,318 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os

# unsloth must be imported before transformers/trl so its patches apply.
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt

import torch
import wandb
from datasets import Dataset, concatenate_datasets, load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments

# --- Run configuration ------------------------------------------------------
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
# NOTE(review): this path is never referenced below — TrainingArguments uses
# output_dir="outputs" (relative). Confirm which directory is intended.
outputs="/home/Mistral-Small-3.1-24B-Base-2503/outputs"

# Start a Weights & Biases run; TrainingArguments below reports to the same
# project/run via report_to="wandb" / run_name="run3".
wandb.init(
    project="Mistral-Small-3.1-24B-Base-2503-SFT",
    name="run3",
)

# Load the 24B base model in 4-bit (QLoRA-style) via unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "mistralai/Mistral-Small-3.1-24B-Base-2503", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Attach LoRA adapters to all attention and MLP projection matrices;
# only these low-rank adapters are trained, the 4-bit base stays frozen.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

# Use the ChatML template for formatting conversations into training text.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    map_eos_token = True, # Maps <|im_end|> to </s> instead
)
def remove_unrelated_columns(dataset):
    """Return a view of *dataset* keeping only the "conversations" column."""
    wanted = ["conversations"]
    return dataset.select_columns(wanted)

def clean_shareGPT_remove_weight(dataset):
    """Rebuild the dataset so each conversation turn keeps only its
    "from" and "value" keys, dropping any extra per-message fields
    (e.g. SlimOrca's "weight")."""
    rows = [
        {
            "conversations": [
                {"from": turn["from"], "value": turn["value"]}
                for turn in row["conversations"]
            ]
        }
        for row in dataset
    ]
    return Dataset.from_list(rows)


def formatting_prompts_func(examples):
    """Batched map helper: render each conversation into one training
    string using the module-level tokenizer's chat template."""
    rendered = []
    for convo in examples["conversations"]:
        text = tokenizer.apply_chat_template(
            convo, tokenize = False, add_generation_prompt = False
        )
        rendered.append(text)
    return {"text": rendered}

def reorder_conversations(example):
    """Return a copy of *example* whose conversation dicts list "role"
    before "content", normalizing key order across datasets."""
    normalized = [
        {"role": turn["role"], "content": turn["content"]}
        for turn in example["conversations"]
    ]
    return {"conversations": normalized}

# --- Dataset 1: Sonnet 3.5 character-card roleplay --------------------------
ds1 = load_dataset("Gryphe/Sonnet3.5-Charcard-Roleplay", split = "train")
ds1 = standardize_sharegpt(ds1)  # ShareGPT "from"/"value" -> "role"/"content"
ds1 = ds1.map(reorder_conversations)
ds1 = ds1.map(formatting_prompts_func, batched = True,)

# --- Dataset 2: anime-character roleplay (already role/content) -------------
ds2 = load_dataset("zerofata/Roleplay-Anime-Characters", split = "train")
ds2 = ds2.rename_column("messages", "conversations")
ds2 = remove_unrelated_columns(ds2)
ds2 = ds2.map(reorder_conversations)
ds2 = ds2.map(formatting_prompts_func, batched = True,)

# --- Dataset 3: SlimOrca instruction data -----------------------------------
ds3 = load_dataset("Open-Orca/SlimOrca", split="train")
ds3 = remove_unrelated_columns(ds3)
ds3 = clean_shareGPT_remove_weight(ds3)  # drop extra per-turn fields (e.g. weights)
ds3 = standardize_sharegpt(ds3)
ds3 = ds3.map(reorder_conversations)
ds3 = ds3.select(range(20000))  # subsample to 20k examples
ds3 = ds3.map(formatting_prompts_func, batched = True,)

# Keep only the rendered "text" field; drop "conversations" so the three
# datasets share an identical schema before concatenation.
ds1 = ds1.remove_columns([col for col in ds1.column_names if col != "text"])
ds2 = ds2.remove_columns([col for col in ds2.column_names if col != "text"])
ds3 = ds3.remove_columns([col for col in ds3.column_names if col != "text"])
# print(ds1.features)
# print(ds2.features)
# print(ds3.features)

# for i in range(3):
#     print("=" * 60)
#     print(ds1[i]["text"])
#     print(ds2[i]["text"])
#     print(ds3[i]["text"])

# Merge all three into the final training corpus.
ds = concatenate_datasets([ds1, ds2, ds3])

# Supervised fine-tuning over the concatenated "text" corpus.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = ds,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4, # effective batch = 4 * 4 = 16 per device
        warmup_ratio = 0.01, # 1% of total steps (earlier comment said 3% — the value is 0.01)
        #warmup_steps = 5,
        #max_steps = 60,
        learning_rate = 4e-5,
        fp16 = not is_bfloat16_supported(), # fall back to fp16 only when bf16 unavailable
        bf16 = is_bfloat16_supported(),
        logging_steps = 10,
        optim = "adamw_8bit", # 8-bit AdamW to cut optimizer memory
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        # NOTE(review): relative path; the absolute `outputs` variable defined
        # at the top of the file is unused — confirm which one is intended.
        output_dir = "outputs",
        report_to="wandb",
        run_name="run3",
    ),
)

# Run SFT training, then export the merged model (LoRA folded into base).
trainer_stats = trainer.train()
model.save_pretrained_merged("/home/Mistral-Small-3.1-24B-Base-2503/model_1", tokenizer, save_method = "merged_16bit",)
# Read the Hub token from the environment instead of the hard-coded
# placeholder "还没写" ("not written yet"): the placeholder would always
# fail authentication, and a real token must never be committed in source.
hub_token = os.environ.get("HF_TOKEN")
model.push_to_hub_merged("hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT-1", tokenizer, save_method = "merged_16bit", token = hub_token)