File size: 2,797 Bytes
5199dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# /// script
# dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
# ///

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os

# Configuration: the base checkpoint to fine-tune, the instruction dataset to
# train on, and the Hub repository the trained result is pushed to.
MODEL_ID: str = "Qwen/Qwen2.5-32B-Instruct"
DATASET_ID: str = "sunkencity/survival-instruct"
OUTPUT_MODEL_ID: str = "sunkencity/survival-expert-qwen-32b"

# Load Dataset (only the "train" split is used)
dataset = load_dataset(path=DATASET_ID, split="train")

# SANITIZE DATASET
def filter_empty(example):
    """Return True if a row has a usable instruction and response.

    A row is kept only when both ``example["instruction"]`` and
    ``example["response"]`` are not None and contain at least one
    non-whitespace character. The None checks for both fields run before any
    ``.strip()`` call, so a None in either field rejects the row before a
    string method is ever invoked.
    """
    fields = ("instruction", "response")
    if any(example[name] is None for name in fields):
        return False
    return all(len(example[name].strip()) > 0 for name in fields)

dataset = dataset.filter(filter_empty)

# Load Model
# 4-bit NF4 quantization is essential for 32B on a single A100 if we want a
# decent batch size; the quantized matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # A100 supports bfloat16 natively
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    use_cache=False,  # the KV cache only helps generation, not training
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Only fall back to EOS for padding when the tokenizer defines no pad token.
# Overwriting an existing pad token with EOS makes padding indistinguishable
# from the real end-of-sequence token: any collator that masks label positions
# equal to pad_token_id would then also mask the EOS the model needs to learn
# to emit, and the fine-tuned model may fail to stop generating.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# MANUAL FORMATTING
def format_row(example, eos_token=None):
    """Render one instruction/response pair as a single Qwen ChatML string.

    Args:
        example: A dataset row with ``"instruction"`` and ``"response"`` fields.
        eos_token: String used to terminate the sample. Defaults to the
            module-level ``tokenizer.eos_token``. It is appended only if the
            text does not already end with it, and skipped when falsy.

    Returns:
        ``{"text": formatted}`` — the column consumed via
        ``dataset_text_field="text"``.
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token
    instruction = example["instruction"]
    response = example["response"]
    # Qwen Chat Template
    # <|im_start|>user
    # {instruction}<|im_end|>
    # <|im_start|>assistant
    # {response}<|im_end|>
    # NOTE(review): the tokenizer's own chat template may also insert a default
    # system turn that this manual format omits — confirm training and
    # inference prompts match.
    text = (
        f"<|im_start|>user\n{instruction}<|im_end|>\n"
        f"<|im_start|>assistant\n{response}<|im_end|>"
    )
    # The assistant turn already ends with "<|im_end|>". When that is also the
    # tokenizer's EOS (presumably the case for Qwen Instruct tokenizers —
    # verify against the tokenizer config), appending EOS unconditionally
    # produced a doubled stop token. Append it only when it adds something.
    if eos_token and not text.endswith(eos_token):
        text += eos_token
    return {"text": text}

dataset = dataset.map(format_row)

# LoRA adapters on the q/k/v/o attention and gate/up/down MLP projections.
peft_config = LoraConfig(
    r=32,  # Increased rank for larger model capability
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Args
# Effective batch per device per optimizer step: 4 sequences x 4 accumulation
# steps = 16 sequences.
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Activations for 4 x 2048-token sequences through a 32B model are likely
    # too large for one 80GB GPU without recomputation. Set this explicitly
    # instead of relying on a library default: the script's dependencies are
    # unpinned and the default has differed between TRL versions. The
    # non-reentrant variant is the recommended mode with PEFT adapters.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=1e-4,
    logging_steps=5,
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL_ID,
    fp16=False,
    bf16=True,  # Enable BF16 for A100
    packing=False,
    max_length=2048,  # Increased context length for 32B
    dataset_text_field="text",
)

# Trainer: given peft_config, SFTTrainer attaches the LoRA adapters to `model`.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    args=training_args,
    processing_class=tokenizer,
)

print("Starting training...")
trainer.train()

# Final upload to hub_model_id after training completes.
print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")