# training-scripts / train_survival_32b.py
# Uploaded by sunkencity ("Upload train_survival_32b.py with huggingface_hub", commit 5199dbe, verified).
# /// script
# dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# Configuration: base checkpoint to fine-tune, instruction dataset, and the
# Hub repository the trained adapter is pushed to.
MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
DATASET_ID = "sunkencity/survival-instruct"
OUTPUT_MODEL_ID = "sunkencity/survival-expert-qwen-32b"

# Load only the training split of the instruction dataset.
dataset = load_dataset(path=DATASET_ID, split="train")
# SANITIZE DATASET
def filter_empty(example):
    """Return True when a row has a usable instruction and response.

    A row is kept only if both ``"instruction"`` and ``"response"`` are
    strings containing at least one non-whitespace character. ``None``,
    missing keys and non-string values are rejected (return False) instead
    of raising, so one malformed row cannot abort ``dataset.filter``.

    Args:
        example: A mapping-like dataset row.

    Returns:
        bool: True to keep the row, False to drop it.
    """

    def _has_text(value):
        # The isinstance check also rejects None; the previous version called
        # .strip() on any non-None value and would raise on non-strings.
        return isinstance(value, str) and bool(value.strip())

    return _has_text(example.get("instruction")) and _has_text(example.get("response"))
# Drop rows whose instruction or response is missing or blank.
dataset = dataset.filter(function=filter_empty)
# Load Model
# 4-bit NF4 quantization is essential for 32B on a single A100 if we want a
# decent batch size; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # A100 supports bf16 natively
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    use_cache=False,  # KV cache is not needed during training
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Only fall back to EOS-as-padding when the tokenizer has no pad token.
# Qwen2.5 tokenizers ship a dedicated pad token (<|endoftext|>). Overwriting it
# with the EOS token (<|im_end|>) makes padding indistinguishable from
# end-of-turn, and any collator that masks labels by pad_token_id would then
# also mask every <|im_end|>, so the model would never learn to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# MANUAL FORMATTING
def format_row(example, eos_token=None):
    """Render one instruction/response row as a Qwen (ChatML-style) transcript.

    The produced text is::

        <|im_start|>user
        {instruction}<|im_end|>
        <|im_start|>assistant
        {response}<|im_end|>

    followed by the EOS token only if the text does not already end with it.
    The Qwen2.5 Instruct tokenizer's EOS token is ``<|im_end|>``, so the
    previous unconditional ``{tokenizer.eos_token}`` suffix produced
    ``<|im_end|><|im_end|>`` and trained the model to close every turn twice.

    NOTE(review): Qwen2.5's own chat template inserts a default system message
    when none is given; this manual format omits it, so prompts built with
    ``tokenizer.apply_chat_template`` at inference time will differ slightly
    from the training format. Confirm this is intended.

    Args:
        example: A dataset row with ``"instruction"`` and ``"response"`` fields.
        eos_token: EOS string to append when missing. Defaults to the
            module-level ``tokenizer.eos_token``. A falsy value (``None`` from
            the tokenizer, or ``""``) appends nothing instead of the literal
            text "None".

    Returns:
        dict: ``{"text": <formatted transcript>}``, suitable for ``dataset.map``.
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token
    text = (
        f"<|im_start|>user\n{example['instruction']}<|im_end|>\n"
        f"<|im_start|>assistant\n{example['response']}<|im_end|>"
    )
    # Append EOS only when it is a distinct token (e.g. a base-model tokenizer
    # whose EOS is <|endoftext|>); never duplicate a trailing <|im_end|>.
    if eos_token and not text.endswith(eos_token):
        text += eos_token
    return {"text": text}
# Add the rendered "text" column that the trainer reads (dataset_text_field="text").
dataset = dataset.map(function=format_row)
# LoRA adapter configuration: rank 32 with alpha = 2 * r, applied to the
# attention (q/k/v/o) and MLP (gate/up/down) projection modules, matched by name.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,  # raised rank to give the larger model more adapter capacity
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
)
# Args
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,  # A100 has 80GB, we can afford larger batches
    gradient_accumulation_steps=4,  # 4 x 4 = 16 sequences per device per optimizer step
    learning_rate=1e-4,
    logging_steps=5,
    # Set explicitly: TrainingArguments defaults this to False and only some TRL
    # releases override it in SFTConfig, while the dependency is unpinned.
    # Activations of a 32B model at batch 4 x 2048 tokens are very unlikely
    # to fit in 80 GB without checkpointing.
    gradient_checkpointing=True,
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL_ID,
    fp16=False,
    bf16=True,  # bf16 mixed precision (supported on A100)
    packing=False,
    max_length=2048,  # longer context for the 32B run
    dataset_text_field="text",
)
# Trainer: passing peft_config makes SFTTrainer train LoRA adapters on top of
# the quantized base model; the tokenizer serves as the processing class.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)


def _announce_and_run(message, action):
    """Print a progress message, then invoke the zero-argument callable."""
    print(message)
    action()


_announce_and_run("Starting training...", trainer.train)
_announce_and_run("Pushing to hub...", trainer.push_to_hub)
print("Done!")