# /// script
# dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
# ///
# QLoRA supervised fine-tuning script: adapts Qwen2.5-32B-Instruct on a
# survival-instruction dataset and pushes the result to the Hugging Face Hub.
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os  # NOTE(review): unused in the visible code — kept in case a later chunk needs it

# Configuration
MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"          # base model to fine-tune
DATASET_ID = "sunkencity/survival-instruct"     # source dataset (instruction/response columns)
OUTPUT_MODEL_ID = "sunkencity/survival-expert-qwen-32b"  # Hub repo for the trained adapter

# Load Dataset — downloads from the Hub; only the train split is used
dataset = load_dataset(DATASET_ID, split="train")
# SANITIZE DATASET
def filter_empty(example):
    """Return True when both instruction and response carry real content.

    A row is rejected if either field is missing (None) or is empty /
    whitespace-only after stripping, so the trainer never sees blank
    prompts or blank targets.
    """
    instruction = example["instruction"]
    response = example["response"]
    # Guard clause: a missing field disqualifies the row immediately.
    if instruction is None or response is None:
        return False
    # Non-empty after stripping whitespace on both sides.
    return bool(instruction.strip()) and bool(response.strip())
# Drop rows with missing/blank instruction or response before formatting.
dataset = dataset.filter(filter_empty)

# Load Model
# 4-bit quantization is essential for 32B on single A100 if we want decent batch size
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 — standard QLoRA quant type
    bnb_4bit_compute_dtype=torch.bfloat16,  # Using bfloat16 for A100
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",            # shard/offload automatically across available devices
    use_cache=False,              # KV cache is useless during training and wastes memory
    torch_dtype=torch.bfloat16,   # non-quantized weights (norms, embeddings) in bf16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# No dedicated pad token on this tokenizer; reuse EOS so batching can pad.
tokenizer.pad_token = tokenizer.eos_token
# MANUAL FORMATTING
def format_row(example):
    """Render one dataset row into Qwen's ChatML training format.

    Produces:
        <|im_start|>user\n{instruction}<|im_end|>\n
        <|im_start|>assistant\n{response}<|im_end|>

    Returns a dict with a single "text" key, as expected by
    SFTConfig(dataset_text_field="text").

    Fix: the original also appended ``tokenizer.eos_token`` after the final
    ``<|im_end|>``. For Qwen2.5-Instruct the EOS token *is* ``<|im_end|>``,
    so every sample ended with a doubled end-of-turn token — redundant and
    slightly off-distribution for the chat template. The single ``<|im_end|>``
    already terminates the sequence.
    """
    instruction = example['instruction']
    response = example['response']
    text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
    return {"text": text}
# Add the rendered "text" column; original columns are left in place.
dataset = dataset.map(format_row)

# LoRA adapter configuration (QLoRA: adapters train on top of the 4-bit base)
peft_config = LoraConfig(
    r=32,              # Increased rank for larger model capability
    lora_alpha=64,     # scaling factor; alpha/r = 2, a common LoRA ratio
    lora_dropout=0.05,
    bias="none",       # leave bias terms frozen
    task_type="CAUSAL_LM",
    # Target every attention and MLP projection — full-coverage LoRA.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# Args
training_args = SFTConfig(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4, # A100 has 80GB, we can afford larger batches
gradient_accumulation_steps=4,
learning_rate=1e-4,
logging_steps=5,
push_to_hub=True,
hub_model_id=OUTPUT_MODEL_ID,
fp16=False,
bf16=True, # Enable BF16 for A100
packing=False,
max_length=2048, # Increased context length for 32B
dataset_text_field="text"
)
# Trainer
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
args=training_args,
processing_class=tokenizer,
)
print("Starting training...")
trainer.train()
print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")