import os

# Unsloth is kept first: its documentation asks for it to be imported before
# transformers/trl so that its patches are applied.
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
from unsloth.chat_templates import standardize_sharegpt
import torch
import wandb
from datasets import Dataset
from datasets import load_dataset, concatenate_datasets
from transformers import TrainingArguments
from trl import SFTTrainer


# --- Run configuration -------------------------------------------------------
# Sequence length handed to the model loader and to the trainer.
max_seq_length = 2048
# Forwarded unchanged to the loader's `dtype` argument.
dtype = None
# Forwarded unchanged to the loader's `load_in_4bit` argument.
load_in_4bit = True
# Absolute path of this run's outputs directory.
outputs = "/home/Mistral-Small-3.1-24B-Base-2503/outputs"

# Open the Weights & Biases run before any model work starts.
wandb.init(project="Mistral-Small-3.1-24B-Base-2503-SFT", name="run3")

# Load the base checkpoint together with its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="mistralai/Mistral-Small-3.1-24B-Base-2503",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# Module names that receive LoRA adapters.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Wrap the loaded model with LoRA adapters (rank 16, alpha 16, no dropout).
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=lora_target_modules,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Install the ChatML template on the tokenizer, with `map_eos_token` enabled.
tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",
    map_eos_token=True,
)
def remove_unrelated_columns(dataset):
    """Return ``dataset`` restricted to its ``conversations`` column.

    Args:
        dataset: Object exposing ``select_columns(list_of_names)``.

    Returns:
        Whatever ``dataset.select_columns`` returns for ``["conversations"]``.
    """
    columns_to_keep = ["conversations"]
    return dataset.select_columns(columns_to_keep)


def clean_shareGPT_remove_weight(dataset):
    """Rebuild ``dataset`` keeping only ``from`` and ``value`` on every turn.

    Each record's ``conversations`` list is copied turn by turn; any other
    per-turn key is left out of the copy.

    Args:
        dataset: Iterable of records that each carry a ``conversations`` list
            of dicts with ``from`` and ``value`` keys.

    Returns:
        A new ``Dataset`` with a single ``conversations`` column.
    """

    def _strip_turn(turn):
        # Copy just the two keys that are kept.
        return {"from": turn["from"], "value": turn["value"]}

    rows = [
        {"conversations": [_strip_turn(turn) for turn in record["conversations"]]}
        for record in dataset
    ]
    return Dataset.from_list(rows)


def formatting_prompts_func(examples):
    """Render every conversation in a batch to a single training string.

    Uses the module-level ``tokenizer`` and its chat template; nothing is
    tokenized and no generation prompt is appended.

    Args:
        examples: Batch mapping whose ``conversations`` entry is a list of
            conversations.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation.
    """
    rendered = []
    for conversation in examples["conversations"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}


def reorder_conversations(example):
    """Rebuild each message as ``{"role": ..., "content": ...}``.

    Every message ends up with exactly those two keys, in that order; any
    other key on the incoming message is not carried over.

    Args:
        example: Record with a ``conversations`` list of dicts that each have
            ``role`` and ``content`` keys.

    Returns:
        ``{"conversations": [...]}`` holding the rebuilt messages.
    """
    normalized = [
        {"role": turn["role"], "content": turn["content"]}
        for turn in example["conversations"]
    ]
    return {"conversations": normalized}


def _render_text_only(dataset):
    """Normalize turns, render them with the chat template, keep only ``text``."""
    dataset = dataset.map(reorder_conversations)
    dataset = dataset.map(formatting_prompts_func, batched=True)
    extra_columns = [col for col in dataset.column_names if col != "text"]
    return dataset.remove_columns(extra_columns)


# Number of leading SlimOrca rows mixed into the training set.
slimorca_rows = 20000

# Dataset 1: passed through standardize_sharegpt before rendering.
ds1 = load_dataset("Gryphe/Sonnet3.5-Charcard-Roleplay", split="train")
ds1 = standardize_sharegpt(ds1)
ds1 = _render_text_only(ds1)

# Dataset 2: its "messages" column is renamed to "conversations"; it is not
# passed through standardize_sharegpt.
ds2 = load_dataset("zerofata/Roleplay-Anime-Characters", split="train")
ds2 = ds2.rename_column("messages", "conversations")
ds2 = remove_unrelated_columns(ds2)
ds2 = _render_text_only(ds2)

# Dataset 3: cut down to the leading rows *before* the row-by-row Python
# cleaning, so only the rows that are actually kept get processed. The cap is
# guarded so a shorter split does not raise.
# NOTE(review): this assumes standardize_sharegpt treats rows independently
# apart from inspecting the leading examples - confirm against Unsloth.
ds3 = load_dataset("Open-Orca/SlimOrca", split="train")
ds3 = remove_unrelated_columns(ds3)
ds3 = ds3.select(range(min(slimorca_rows, len(ds3))))
ds3 = clean_shareGPT_remove_weight(ds3)
ds3 = standardize_sharegpt(ds3)
ds3 = _render_text_only(ds3)

# Every part now has the single column "text", so they concatenate cleanly.
ds = concatenate_datasets([ds1, ds2, ds3])

# Decide the mixed-precision mode once and reuse it for both flags.
use_bf16 = is_bfloat16_supported()

# Supervised fine-tuning on the pre-rendered "text" column.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        # 4 samples x 4 accumulation steps = 16 samples per optimizer step
        # on each device.
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_ratio=0.01,
        # No epoch or step budget is given here, so the library default applies.
        learning_rate=4e-5,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        # Use the configured `outputs` path rather than a literal relative
        # "outputs" directory that depends on the current working directory.
        output_dir=outputs,
        report_to="wandb",
        run_name="run3",
    ),
)

# Run training; the returned statistics are kept for later inspection.
trainer_stats = trainer.train()

# Save a merged 16-bit copy of the model and tokenizer locally.
model.save_pretrained_merged(
    "/home/Mistral-Small-3.1-24B-Base-2503/model_1",
    tokenizer,
    save_method="merged_16bit",
)

# Push the same merged 16-bit model to the Hub. The token is read from the
# HF_TOKEN environment variable so that no placeholder or secret lives in the
# source; when the variable is unset, None is passed instead.
model.push_to_hub_merged(
    "hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT-1",
    tokenizer,
    save_method="merged_16bit",
    token=os.environ.get("HF_TOKEN"),
)