sunkencity committed on
Commit
3468e66
·
verified ·
1 Parent(s): 20566c4

Upload train_survival.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_survival.py +89 -0
train_survival.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
3
+ # ///
4
+
5
+ from datasets import load_dataset
6
+ from peft import LoraConfig
7
+ from trl import SFTTrainer, SFTConfig
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
9
+ import torch
10
+ import os
11
+
12
# Configuration
MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"              # base chat model to fine-tune
DATASET_ID = "sunkencity/survival-instruct"        # Hub dataset of instruction/response pairs
OUTPUT_MODEL_ID = "sunkencity/survival-expert-3b"  # Hub repo the fine-tuned adapter is pushed to

# Load Dataset
# NOTE(review): assumes the dataset exposes 'instruction' and 'response'
# columns (consumed by formatting_prompts_func below) — confirm against
# the dataset card on the Hub.
dataset = load_dataset(DATASET_ID, split="train")
19
+
20
# Load Model with Quantization (for efficiency)
# 4-bit NF4 quantization (QLoRA-style) so the 3B model fits in modest GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # compute runs in fp16 on top of 4-bit weights
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",   # let accelerate place layers on the available device(s)
    use_cache=False      # KV cache is unused during training; disabling saves memory
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# NOTE(review): reusing EOS as the pad token is a common shortcut, but pad
# positions become indistinguishable from EOS in the labels — confirm the
# trainer's collator masks padding from the loss.
tokenizer.pad_token = tokenizer.eos_token
35
+
36
# LoRA Configuration
peft_config = LoraConfig(
    r=16,               # adapter rank
    lora_alpha=32,      # scaling factor (effective scale = alpha / r = 2.0)
    lora_dropout=0.05,
    bias="none",        # leave bias terms frozen
    task_type="CAUSAL_LM",
    # Adapt all attention and MLP projections — full-coverage set for
    # Qwen/Llama-style transformer blocks.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)
45
+
46
# Training Arguments
# Effective batch size = per_device (4) * grad accumulation (4) = 16 per optimizer step.
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,            # typical LR for LoRA fine-tuning
    logging_steps=10,
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL_ID,
    fp16=True,                     # matches the 4-bit compute dtype above
    max_seq_length=1024,
    # Fixed: removed dataset_text_field="text". The trainer below supplies a
    # formatting_func, which takes precedence per TRL's SFT docs, and the
    # dataset has no "text" column (it has instruction/response) — so the
    # field was dead config at best and an error in TRL versions that honor it.
    packing=False                  # one example per sequence; no example concatenation
)
61
+
62
# Formatting function for SFT (Chat format)
def formatting_prompts_func(example):
    """Render instruction/response pairs into the Qwen (ChatML) chat format.

    Supports both calling conventions used by TRL's SFTTrainer:
    - batched: ``example`` maps column names to lists  -> returns list[str]
    - single:  ``example`` maps column names to strings -> returns str

    The original implementation handled only the batched form; given a single
    example (string values) it would silently iterate over characters.
    """
    def render(instruction, response):
        # ChatML template used by Qwen (and Llama-style chat models).
        return f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"

    instructions = example['instruction']
    responses = example['response']
    if isinstance(instructions, str):
        # Non-batched call: one example at a time.
        return render(instructions, responses)
    return [render(i, r) for i, r in zip(instructions, responses)]
73
+
74
# Trainer
# SFTTrainer applies peft_config itself: the 4-bit base model stays frozen
# and only the LoRA adapter weights are trained.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    formatting_func=formatting_prompts_func,
    args=training_args,
    # NOTE(review): recent TRL versions renamed this kwarg to
    # `processing_class` — confirm against the pinned trl version.
    tokenizer=tokenizer,
)

print("Starting training...")
trainer.train()

print("Pushing to hub...")
# Uploads the trained artifacts to the Hub repo configured via
# hub_model_id in SFTConfig above.
trainer.push_to_hub()
print("Done!")