"""SFT fine-tuning of Mistral-Small-3.1-24B-Base-2503 with Unsloth QLoRA.

Pipeline:
  1. Load the base model in 4-bit and attach LoRA adapters.
  2. Normalize three chat datasets (Sonnet3.5 role-play, anime role-play,
     SlimOrca) to a common ChatML-rendered ``text`` column.
  3. Train with TRL's SFTTrainer, logging to Weights & Biases.
  4. Save a merged 16-bit model locally and push it to the Hugging Face Hub.
"""

import os

import torch
import wandb
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt

max_seq_length = 2048  # Unsloth auto-supports RoPE scaling for longer contexts.
dtype = None           # None = auto-detect. float16 for T4/V100, bfloat16 for Ampere+.
load_in_4bit = True    # 4-bit quantization (QLoRA) to reduce VRAM usage.
# NOTE(review): defined but never used — `output_dir` below is the relative
# path "outputs"; confirm whether this absolute path was the intended target.
outputs = "/home/Mistral-Small-3.1-24B-Base-2503/outputs"

wandb.init(
    project="Mistral-Small-3.1-24B-Base-2503-SFT",
    name="run3",
)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="mistralai/Mistral-Small-3.1-24B-Base-2503",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # token="hf_...",  # needed only for gated models (e.g. meta-llama)
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank; suggested values: 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # any value supported, but 0 is the optimized fast path
    bias="none",     # any value supported, but "none" is the optimized fast path
    use_gradient_checkpointing="unsloth",  # "unsloth" = ~30% less VRAM, longer context
    random_state=3407,
    use_rslora=False,   # rank-stabilized LoRA (disabled)
    loftq_config=None,  # LoftQ (disabled)
)

tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",  # also supports zephyr, mistral, llama, alpaca, vicuna, ...
    map_eos_token=True,      # map <|im_end|> to the EOS token
)


def remove_unrelated_columns(dataset):
    """Keep only the ``conversations`` column of *dataset*."""
    return dataset.select_columns(["conversations"])


def clean_shareGPT_remove_weight(dataset):
    """Strip extra per-message keys (e.g. SlimOrca's ``weight``).

    Rebuilds every ShareGPT message keeping only ``from`` and ``value`` so
    all rows share an identical schema.
    """
    cleaned = [
        {
            "conversations": [
                {"from": msg["from"], "value": msg["value"]}
                for msg in row["conversations"]
            ]
        }
        for row in dataset
    ]
    return Dataset.from_list(cleaned)


def formatting_prompts_func(examples):
    """Render each conversation to one training string via the chat template.

    Batched ``datasets.map`` callback: consumes ``conversations`` and emits
    a ``text`` column.
    """
    texts = [
        tokenizer.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=False
        )
        for convo in examples["conversations"]
    ]
    return {"text": texts}


def reorder_conversations(example):
    """Rebuild each message with a fixed key order: ``role`` then ``content``.

    Keeps the feature schemas of all three datasets identical so
    ``concatenate_datasets`` does not reject them.
    """
    convos = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in example["conversations"]
    ]
    return {"conversations": convos}


ds1 = load_dataset("Gryphe/Sonnet3.5-Charcard-Roleplay", split="train")
ds1 = standardize_sharegpt(ds1)
ds1 = ds1.map(reorder_conversations)
ds1 = ds1.map(formatting_prompts_func, batched=True)

ds2 = load_dataset("zerofata/Roleplay-Anime-Characters", split="train")
ds2 = ds2.rename_column("messages", "conversations")
ds2 = remove_unrelated_columns(ds2)
ds2 = ds2.map(reorder_conversations)
ds2 = ds2.map(formatting_prompts_func, batched=True)

ds3 = load_dataset("Open-Orca/SlimOrca", split="train")
ds3 = remove_unrelated_columns(ds3)
ds3 = clean_shareGPT_remove_weight(ds3)
ds3 = standardize_sharegpt(ds3)
ds3 = ds3.map(reorder_conversations)
ds3 = ds3.select(range(20000))  # subsample SlimOrca to 20k rows
ds3 = ds3.map(formatting_prompts_func, batched=True)

# Keep only the rendered `text` column so all three datasets share one schema.
ds1 = ds1.remove_columns([col for col in ds1.column_names if col != "text"])
ds2 = ds2.remove_columns([col for col in ds2.column_names if col != "text"])
ds3 = ds3.remove_columns([col for col in ds3.column_names if col != "text"])

ds = concatenate_datasets([ds1, ds2, ds3])

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,  # packing can make training ~5x faster for short sequences
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_ratio=0.01,  # 1% of total steps (original comment wrongly said 3%)
        learning_rate=4e-5,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="outputs",
        report_to="wandb",
        run_name="run3",
    ),
)

trainer_stats = trainer.train()

model.save_pretrained_merged(
    "/home/Mistral-Small-3.1-24B-Base-2503/model_1",
    tokenizer,
    save_method="merged_16bit",
)

# Read the Hub token from the environment instead of hard-coding it.
# The original had the placeholder string "还没写" ("not written yet"),
# which would fail authentication — and a real token committed here would
# leak a credential into version control.
model.push_to_hub_merged(
    "hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT-1",
    tokenizer,
    save_method="merged_16bit",
    token=os.environ.get("HF_TOKEN"),
)