# Source: training-scripts/train_survival.py (Hugging Face Hub)
# Uploaded by sunkencity with huggingface_hub - commit b5f89c4 (verified), 2.63 kB.
# /// script
# dependencies = ["trl", "peft", "bitsandbytes", "datasets", "transformers"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# Hub identifiers used by this run: the base checkpoint to fine-tune, the
# training data, and the repository the fine-tuned model is pushed to.
MODEL_ID, DATASET_ID, OUTPUT_MODEL_ID = (
    "Qwen/Qwen2.5-3B-Instruct",       # base model
    "sunkencity/survival-instruct",   # instruction/response dataset
    "sunkencity/survival-expert-3b",  # destination repo (hub_model_id)
)
# Fetch only the "train" split of the instruction/response dataset.
dataset = load_dataset(path=DATASET_ID, split="train")
# SANITIZE DATASET
def filter_empty(example):
return (
example["instruction"] is not None
and example["response"] is not None
and len(example["instruction"]) > 0
and len(example["response"]) > 0
)
# Drop rows that filter_empty rejects (missing/empty instruction or response).
dataset = dataset.filter(function=filter_empty)
# Load Model: 4-bit NF4 quantized weights, with fp16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    use_cache=False,  # KV cache is only useful for generation, not training
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fall back to EOS for padding only when the checkpoint defines no pad token.
# Unconditionally forcing pad == eos would (a) overwrite a dedicated pad token,
# and that change is pushed to the Hub with the tokenizer, and (b) let any
# collator that masks pad_token_id out of the labels also mask every EOS /
# end-of-turn token, so the model would not learn where to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# LoRA adapter settings: rank 16 with alpha 32, light dropout, no bias
# training, applied to the projection modules listed by name below.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # MLP projections
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
# Trainer hyperparameters. Effective batch per device per optimizer step is
# per_device_train_batch_size (4) x gradient_accumulation_steps (4) = 16.
# dataset_text_field is intentionally not set: the training text comes from
# formatting_prompts_func, and setting both conflicted.
training_args = SFTConfig(
    output_dir="./results",
    # optimization
    num_train_epochs=3,
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    fp16=True,
    # sequence handling: at most 1024 tokens, one example per sequence
    max_length=1024,
    packing=False,
    # logging and Hub upload
    logging_steps=10,
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL_ID,
)
def _format_chatml_pair(instruction, response):
    """Render one user/assistant exchange in ChatML (<|im_start|>/<|im_end|>) markup."""
    return (
        f"<|im_start|>user\n{instruction}<|im_end|>\n"
        f"<|im_start|>assistant\n{response}<|im_end|>"
    )


def formatting_prompts_func(example):
    """Format instruction/response data as ChatML training text.

    Supports both calling conventions of ``SFTTrainer(formatting_func=...)``:

    * single example (current TRL): ``instruction``/``response`` are strings
      and one formatted string is returned;
    * batch (older TRL): both fields are lists and a list of formatted strings
      is returned. Pairs are matched positionally, extra items in the longer
      list are ignored, and pairs with an empty/``None`` side are skipped.

    Args:
        example: Mapping with ``"instruction"`` and ``"response"`` keys.

    Returns:
        str for a single example, list[str] for a batch.

    Raises:
        ValueError: a single example has an empty or missing instruction or
            response (there is no way to "skip" it when one string must be
            returned; ``filter_empty`` is expected to have removed such rows).
    """
    instructions = example["instruction"]
    responses = example["response"]

    if not isinstance(instructions, (list, tuple)):
        # Per-example call: the fields are plain strings. Indexing them as if
        # they were batches would walk individual characters.
        if not instructions or not responses:
            raise ValueError(
                "formatting_prompts_func: example has an empty instruction or response"
            )
        return _format_chatml_pair(instructions, responses)

    # Batched call: zip stops at the shorter list, like the original index loop.
    return [
        _format_chatml_pair(instruction, response)
        for instruction, response in zip(instructions, responses)
        if instruction and response
    ]
# Trainer: rows are rendered by formatting_prompts_func, and peft_config
# supplies the LoRA settings defined above.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()

# Upload the result to the repo configured via hub_model_id.
print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")