File size: 2,617 Bytes
3468e66
 
 
 
 
 
 
 
 
 
 
 
c121008
3468e66
c121008
3468e66
 
 
 
689d2ea
 
 
 
 
fc4e123
79f3f4d
689d2ea
 
 
 
d71ac87
c121008
3468e66
 
 
c121008
3468e66
 
 
 
 
 
9828b8a
c121008
3468e66
 
 
 
8257d75
 
 
 
c121008
8257d75
 
 
 
d71ac87
3468e66
 
 
 
 
 
 
 
 
d71ac87
3468e66
 
 
c121008
 
3468e66
 
 
 
c121008
 
69fb596
79f3f4d
9828b8a
3468e66
 
 
 
 
 
 
 
8257d75
3468e66
 
 
 
 
 
 
c121008
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# /// script
# dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
# ///

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os

# Configuration
# Base model to fine-tune, the instruction dataset, and the Hub repo id the
# resulting fine-tuned model is pushed to.
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
DATASET_ID = "sunkencity/survival-instruct"
OUTPUT_MODEL_ID = "sunkencity/survival-expert-llama-3b"

# Load Dataset
# Only the train split is used. The rest of the script reads the
# "instruction" and "response" columns — TODO confirm against the dataset card.
dataset = load_dataset(DATASET_ID, split="train")

# SANITIZE DATASET
def filter_empty(example):
    """Return True when both the instruction and response fields exist
    (not None) and contain at least one non-whitespace character."""
    instruction = example["instruction"]
    response = example["response"]
    if instruction is None or response is None:
        return False
    return bool(instruction.strip()) and bool(response.strip())

dataset = dataset.filter(filter_empty)  # drop rows with missing/blank instruction or response

# Load Model
# We keep 4-bit loading for memory efficiency, but compute in float32 to avoid kernel issues
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NF4 quantization for the frozen base weights
    bnb_4bit_compute_dtype=torch.float32, # Changed to float32
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on available devices
    use_cache=False,  # KV cache is for generation; disabled while training
    torch_dtype=torch.float32 # Changed to float32
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Llama tokenizers ship without a pad token; reuse EOS so padded batches work.
tokenizer.pad_token = tokenizer.eos_token

# MANUAL FORMATTING
def format_row(example):
    """Render one dataset row into the single training string SFTTrainer
    consumes, terminated with the module-level tokenizer's EOS token.

    Returns a dict with a single "text" key (added as a new column by map).
    """
    prompt = example['instruction']
    completion = example['response']
    return {
        "text": f"Instruction: {prompt}\nResponse: {completion}{tokenizer.eos_token}"
    }

dataset = dataset.map(format_row)  # adds the "text" column consumed via dataset_text_field

# LoRA
# Low-rank adapters on every attention and MLP projection; the 4-bit base
# model stays frozen (QLoRA-style setup).
peft_config = LoraConfig(
    r=16,  # adapter rank
    lora_alpha=32,  # adapter scaling numerator
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# Args
# Effective batch size per device = 1 * 16 = 16 via gradient accumulation.
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=1, # Reduced batch size for FP32
    gradient_accumulation_steps=16, # Increased accumulation to compensate
    learning_rate=2e-4,
    logging_steps=10,
    push_to_hub=True,  # upload checkpoints to the Hub during training
    hub_model_id=OUTPUT_MODEL_ID,
    fp16=False, # Disable Mixed Precision
    bf16=False, # Disable BF16
    packing=False,  # one example per sequence; no sequence packing
    max_length=1024,  # truncate tokenized examples to 1024 tokens
    dataset_text_field="text"  # column produced by format_row above
)

# Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    args=training_args,
    processing_class=tokenizer,
)

print("Starting training...")
trainer.train()

print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")