qmaru committed on
Commit 133f638 · verified · 1 Parent(s): 799745a

Upload train.py

Files changed (1)
  1. train.py +23 -7
train.py CHANGED
@@ -1,8 +1,17 @@
+import os
+
 from unsloth import FastModel
-from datasets import load_dataset
-from trl import SFTConfig, SFTTrainer
 from unsloth.chat_templates import get_chat_template, train_on_responses_only
 import torch
+from trl.trainer.sft_config import SFTConfig
+from trl.trainer.sft_trainer import SFTTrainer
+
+from datasets import load_dataset
+
+torch.backends.cudnn.benchmark = True
+
+use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+use_fp16 = torch.cuda.is_available() and not use_bf16
 
 max_seq_length = 2048
 
@@ -12,10 +21,13 @@ model, tokenizer = FastModel.from_pretrained(
     load_in_4bit=False, # 4 bit quantization to reduce memory
     load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
     full_finetuning=False, # [NEW!] We have full finetuning now!
-    dtype=torch.bfloat16,
+    dtype=torch.bfloat16 if use_bf16 else torch.float16 if use_fp16 else torch.float32,
     # token = "hf_...", # use one if using gated models
 )
 
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+
 model = FastModel.get_peft_model(
     model,
     r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
@@ -43,8 +55,8 @@ tokenizer = get_chat_template(
     chat_template="gemma3",
 )
 
-# dataset = load_dataset("json", data_files="dataset_min.json", split="train")
-dataset = load_dataset("qmaru/gemma3-sms", split="train")
+dataset = load_dataset("json", data_files="datasets/code/data.json", split="train")
+# dataset = load_dataset("qmaru/gemma3-sms", split="train")
 
 dataset = dataset.shuffle(seed=42)
 
@@ -96,7 +108,8 @@ trainer = SFTTrainer(
         remove_unused_columns=True,
         dataloader_pin_memory=True,
         dataloader_num_workers=4,
-        bf16=True, # Use 16-bit precision for training
+        bf16=use_bf16,
+        fp16=use_fp16,
     ),
 )
 
@@ -109,5 +122,8 @@ trainer = train_on_responses_only(
 trainer_stats = trainer.train()
 
 model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
-model.save_pretrained_gguf("model", tokenizer, quantization_method="q8_0")
 model.save_pretrained_gguf("model", tokenizer, quantization_method="f16")
+model.save_pretrained_gguf("model", tokenizer, quantization_method="q8_0")
+# model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
+
+os.system("./llama.cpp/build/bin/llama-quantize model.F16.gguf model.Q4_K_M.gguf 15")
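A note on the final llama.cpp step: llama-quantize takes an input GGUF, an output path, and a quantization type, and the trailing 15 in the committed command is llama.cpp's numeric ftype for Q4_K_M (LLAMA_FTYPE_MOSTLY_Q4_K_M); the tool also accepts the type name directly. Below is a minimal sketch of an equivalent call via subprocess instead of os.system, assuming the same llama.cpp build path and output filenames as the script above:

import subprocess

# Sketch: quantize the exported 16-bit GGUF down to Q4_K_M.
# Passing "Q4_K_M" by name is equivalent to passing the numeric ftype 15.
subprocess.run(
    [
        "./llama.cpp/build/bin/llama-quantize",
        "model.F16.gguf",     # input: F16 GGUF written by save_pretrained_gguf
        "model.Q4_K_M.gguf",  # output: 4-bit k-quant GGUF
        "Q4_K_M",             # quantization type by name
    ],
    check=True,  # raise CalledProcessError if quantization fails
)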