sunkencity committed
Commit 5199dbe · verified · 1 Parent(s): dc8648d

Upload train_survival_32b.py with huggingface_hub

Files changed (1)
  1. train_survival_32b.py +104 -0
train_survival_32b.py ADDED
@@ -0,0 +1,104 @@
+ # /// script
+ # dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
+ # ///
+
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ import torch
+
+ # Configuration
+ MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
+ DATASET_ID = "sunkencity/survival-instruct"
+ OUTPUT_MODEL_ID = "sunkencity/survival-expert-qwen-32b"
+
+ # Load Dataset
+ dataset = load_dataset(DATASET_ID, split="train")
+
+ # SANITIZE DATASET: drop rows with missing or empty fields
+ def filter_empty(example):
+     return (
+         example["instruction"] is not None
+         and example["response"] is not None
+         and len(example["instruction"].strip()) > 0
+         and len(example["response"].strip()) > 0
+     )
+
+ dataset = dataset.filter(filter_empty)
+
+ # Load Model
+ # 4-bit quantization is essential to fit a 32B model on a single A100 with a decent batch size
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 compute, native on A100
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     quantization_config=bnb_config,
+     device_map="auto",
+     use_cache=False,  # no KV cache needed during training
+     torch_dtype=torch.bfloat16,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ tokenizer.pad_token = tokenizer.eos_token  # pad with EOS, a common SFT convention
+
+ # MANUAL FORMATTING
+ def format_row(example):
+     instruction = example["instruction"]
+     response = example["response"]
+     # Qwen Chat Template:
+     # <|im_start|>user
+     # {instruction}<|im_end|>
+     # <|im_start|>assistant
+     # {response}<|im_end|>
+     # <|im_end|> is already the Instruct model's EOS token, so nothing extra is appended
+     text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
+     return {"text": text}
+
+ dataset = dataset.map(format_row)
+
+ # LoRA
+ peft_config = LoraConfig(
+     r=32,  # higher rank to give the 32B model more adapter capacity
+     lora_alpha=64,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+ )
+
+ # Args
+ training_args = SFTConfig(
+     output_dir="./results",
+     num_train_epochs=3,
+     per_device_train_batch_size=4,  # A100 has 80GB, we can afford larger batches
+     gradient_accumulation_steps=4,  # effective batch size: 4 x 4 = 16
+     learning_rate=1e-4,
+     logging_steps=5,
+     push_to_hub=True,
+     hub_model_id=OUTPUT_MODEL_ID,
+     fp16=False,
+     bf16=True,  # enable BF16 for A100
+     packing=False,
+     max_length=2048,  # increased context length for the 32B run
+     dataset_text_field="text",
+ )
+
+ # Trainer
+ trainer = SFTTrainer(
+     model=model,
+     train_dataset=dataset,
+     peft_config=peft_config,
+     args=training_args,
+     processing_class=tokenizer,
+ )
+
+ print("Starting training...")
+ trainer.train()
+
+ print("Pushing to hub...")
+ trainer.push_to_hub()
+ print("Done!")