Upload train.py
Browse files
train.py
CHANGED
|
@@ -1,8 +1,17 @@
|
|
|
|
|
|
|
|
| 1 |
from unsloth import FastModel
|
| 2 |
-
from datasets import load_dataset
|
| 3 |
-
from trl import SFTConfig, SFTTrainer
|
| 4 |
from unsloth.chat_templates import get_chat_template, train_on_responses_only
|
| 5 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
max_seq_length = 2048
|
| 8 |
|
|
@@ -12,10 +21,13 @@ model, tokenizer = FastModel.from_pretrained(
|
|
| 12 |
load_in_4bit=False, # 4 bit quantization to reduce memory
|
| 13 |
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
|
| 14 |
full_finetuning=False, # [NEW!] We have full finetuning now!
|
| 15 |
-
dtype=torch.bfloat16,
|
| 16 |
# token = "hf_...", # use one if using gated models
|
| 17 |
)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
| 19 |
model = FastModel.get_peft_model(
|
| 20 |
model,
|
| 21 |
r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
|
@@ -43,8 +55,8 @@ tokenizer = get_chat_template(
|
|
| 43 |
chat_template="gemma3",
|
| 44 |
)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
dataset = load_dataset("qmaru/gemma3-sms", split="train")
|
| 48 |
|
| 49 |
dataset = dataset.shuffle(seed=42)
|
| 50 |
|
|
@@ -96,7 +108,8 @@ trainer = SFTTrainer(
|
|
| 96 |
remove_unused_columns=True,
|
| 97 |
dataloader_pin_memory=True,
|
| 98 |
dataloader_num_workers=4,
|
| 99 |
-
bf16=
|
|
|
|
| 100 |
),
|
| 101 |
)
|
| 102 |
|
|
@@ -109,5 +122,8 @@ trainer = train_on_responses_only(
|
|
| 109 |
trainer_stats = trainer.train()
|
| 110 |
|
| 111 |
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
| 112 |
-
model.save_pretrained_gguf("model", tokenizer, quantization_method="q8_0")
|
| 113 |
model.save_pretrained_gguf("model", tokenizer, quantization_method="f16")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
from unsloth import FastModel
|
|
|
|
|
|
|
| 4 |
from unsloth.chat_templates import get_chat_template, train_on_responses_only
|
| 5 |
import torch
|
| 6 |
+
from trl.trainer.sft_config import SFTConfig
|
| 7 |
+
from trl.trainer.sft_trainer import SFTTrainer
|
| 8 |
+
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
# Let cuDNN auto-tune convolution algorithms for fixed-size inputs.
torch.backends.cudnn.benchmark = True

# Select the mixed-precision mode once, up front:
# prefer bf16 on GPUs that support it, fall back to fp16 on other CUDA
# devices, and use neither on CPU-only hosts.
_cuda_ok = torch.cuda.is_available()
use_bf16 = _cuda_ok and torch.cuda.is_bf16_supported()
use_fp16 = _cuda_ok and not use_bf16

max_seq_length = 2048
|
| 17 |
|
|
|
|
| 21 |
load_in_4bit=False, # 4 bit quantization to reduce memory
|
| 22 |
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
|
| 23 |
full_finetuning=False, # [NEW!] We have full finetuning now!
|
| 24 |
+
dtype=torch.bfloat16 if use_bf16 else torch.float16 if use_fp16 else torch.float32,
|
| 25 |
# token = "hf_...", # use one if using gated models
|
| 26 |
)
|
| 27 |
|
| 28 |
+
# Release cached CUDA allocator blocks before building the PEFT model,
# so peak device memory is as low as possible for the next stage.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
|
| 30 |
+
|
| 31 |
model = FastModel.get_peft_model(
|
| 32 |
model,
|
| 33 |
r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
|
|
|
| 55 |
chat_template="gemma3",
|
| 56 |
)
|
| 57 |
|
| 58 |
+
# Train on the local JSON corpus; the hosted qmaru/gemma3-sms dataset
# below is kept as the previous (alternative) source.
_data_file = "datasets/code/data.json"
dataset = load_dataset("json", data_files=_data_file, split="train")
# dataset = load_dataset("qmaru/gemma3-sms", split="train")

# Fixed seed keeps the shuffle order reproducible across runs.
dataset = dataset.shuffle(seed=42)
|
| 62 |
|
|
|
|
| 108 |
remove_unused_columns=True,
|
| 109 |
dataloader_pin_memory=True,
|
| 110 |
dataloader_num_workers=4,
|
| 111 |
+
bf16=use_bf16,
|
| 112 |
+
fp16=use_fp16,
|
| 113 |
),
|
| 114 |
)
|
| 115 |
|
|
|
|
| 122 |
import subprocess

# Run fine-tuning; `trainer` is configured earlier in this script.
trainer_stats = trainer.train()

# Export merged (base + adapter) 16-bit weights, then GGUF variants
# for llama.cpp consumption.
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
model.save_pretrained_gguf("model", tokenizer, quantization_method="f16")
model.save_pretrained_gguf("model", tokenizer, quantization_method="q8_0")
# model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")

# Quantize the F16 GGUF down to Q4_K_M with llama.cpp's CLI tool.
# "15" is llama.cpp's numeric ftype id for Q4_K_M.
# NOTE(review): assumes the F16 export lands at ./model.F16.gguf — confirm
# against save_pretrained_gguf's actual output path.
# Use an argument list with shell=False instead of os.system's shell string,
# and surface failures explicitly (os.system silently ignored the exit code).
result = subprocess.run(
    [
        "./llama.cpp/build/bin/llama-quantize",
        "model.F16.gguf",
        "model.Q4_K_M.gguf",
        "15",
    ],
    shell=False,
)
if result.returncode != 0:
    raise RuntimeError(
        f"llama-quantize failed with exit code {result.returncode}"
    )