# training-scripts / train_survival.py
# Uploaded by sunkencity via huggingface_hub ("Upload train_survival.py with
# huggingface_hub"), commit c121008 (verified).
# /// script
# dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# Configuration
# Base checkpoint that the LoRA adapter below is trained on top of.
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# Hub repo the trained result is pushed to (used as hub_model_id).
OUTPUT_MODEL_ID = "sunkencity/survival-expert-llama-3b"
# Hub dataset whose rows carry "instruction" and "response" fields.
DATASET_ID = "sunkencity/survival-instruct"

# Load Dataset: only the "train" split is used; no eval set is passed to the trainer.
dataset = load_dataset(path=DATASET_ID, split="train")
# SANITIZE DATASET
def filter_empty(example):
    """Return True when a row has a usable instruction and response.

    A field is unusable when it is None or contains only whitespace. Rows
    for which this returns False are dropped by ``dataset.filter``.

    Args:
        example: One dataset row, read through its "instruction" and
            "response" keys.

    Returns:
        bool: True to keep the row, False to drop it.
    """
    instruction = example["instruction"]
    if instruction is None:
        return False
    response = example["response"]
    if response is None:
        return False
    # Whitespace-only text counts as empty.
    return bool(instruction.strip()) and bool(response.strip())
# Drop rows whose instruction or response is missing or blank before formatting.
dataset = dataset.filter(function=filter_empty)
# Load Model
# Base weights are stored as 4-bit NF4 to save memory, while compute runs in
# float32. float32 compute is a deliberate workaround for kernel issues seen
# with half precision; it costs speed and activation memory, which is why the
# training args use a per-device batch size of 1.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,       # dtype for modules bitsandbytes leaves unquantized
    device_map="auto",               # let accelerate place the layers on available devices
    quantization_config=bnb_config,
    use_cache=False,                 # KV cache is not needed while training
)

# Pad with the EOS token.
# NOTE(review): pad == eos; if the collator in use masks labels by pad id, the
# trailing EOS appended in format_row would be masked too - confirm against the
# installed TRL version.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
# MANUAL FORMATTING
def format_row(example, eos_token=None):
    """Render one instruction/response pair as a single training string.

    The output is ``"Instruction: <instruction>\\nResponse: <response><eos>"``,
    returned under the ``"text"`` key named by ``SFTConfig.dataset_text_field``.
    The trailing EOS token marks where a response ends.

    NOTE(review): this plain-text template is not the model's chat template,
    so prompts at inference time should use the same "Instruction:/Response:"
    layout.

    Args:
        example: A dataset row with "instruction" and "response" keys.
        eos_token: End-of-sequence string to append. When None (the default),
            the module-level ``tokenizer.eos_token`` is used, so
            ``dataset.map(format_row)`` keeps working unchanged.

    Returns:
        dict: ``{"text": <formatted string>}``.
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token
    text = (
        f"Instruction: {example['instruction']}\n"
        f"Response: {example['response']}{eos_token}"
    )
    return {"text": text}
# Add the "text" column that the trainer reads (dataset_text_field="text").
dataset = dataset.map(function=format_row)
# LoRA
# Rank-16 adapters (scaling alpha / r = 2) on every attention and MLP projection
# of the Llama blocks. bias="none" keeps all bias terms frozen.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention
        "gate_proj", "up_proj", "down_proj",     # MLP
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
# Args
training_args = SFTConfig(
    # Data: train on the formatted "text" column; tokenized examples are capped
    # at 1024 tokens and are not packed together.
    dataset_text_field="text",
    max_length=1024,
    packing=False,
    # Optimization: effective batch = 1 x 16 = 16 sequences per device per step.
    # The small per-device batch compensates for float32 memory use.
    num_train_epochs=3,
    learning_rate=2e-4,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    # Precision: both mixed-precision modes are off (pure float32 training).
    fp16=False,
    bf16=False,
    # Output, logging and Hub upload.
    output_dir="./results",
    logging_steps=10,
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL_ID,
)
# Trainer
# Passing peft_config makes SFTTrainer attach the LoRA adapters to the quantized
# base model, so only the adapter weights are trained.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()

# Final upload to the repo named by hub_model_id in the training args.
print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")