#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.19.0",
# "peft>=0.11.1",
# "transformers>=4.41.2",
# "accelerate>=0.30.1",
# "datasets>=2.19.1",
# "bitsandbytes>=0.43.1",
# "trackio"
# ]
# ///
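# The block above is PEP 723 inline script metadata: a runner such as
# `uv run run_sft_job.py` reads it and installs these dependencies into an
# ephemeral environment before executing the script.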
"""
SFT training script for Qwen/Qwen2.5-0.5B-Instruct on the corrected itemsety
dataset, loaded directly from a private GitHub repository.
Fine-tuning uses 4-bit QLoRA with LoRA adapters on all linear projections.
"""
import os
import shutil
import subprocess

import torch
from datasets import load_from_disk
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
# --- 1. Load Dataset from GitHub ---
# Read the GitHub token from the environment rather than hardcoding it; a
# token committed in source is leaked the moment the script is shared.
GIT_TOKEN = os.environ["GIT_TOKEN"]
GIT_REPO_URL = f"https://{GIT_TOKEN}@github.com/oliversl1vka/itemsety-qwen-finetuning.git"
CLONE_PATH = "/tmp/itemsety-qwen-finetuning"
DATASET_PATH = f"{CLONE_PATH}/hf_dataset_enhanced"
print("📦 Cloning private dataset from GitHub...")
subprocess.run(["git", "clone", "--depth", "1", GIT_REPO_URL, CLONE_PATH], check=True)
print("βœ… Git clone complete.")
# Security: remove the .git directory so the token embedded in the clone URL
# (persisted in .git/config) does not linger on the filesystem
print("🔐 Removing .git directory for security...")
shutil.rmtree(f"{CLONE_PATH}/.git")
print("βœ… .git directory removed.")
print(f"πŸ’Ύ Loading dataset from disk at {DATASET_PATH}...")
dataset = load_from_disk(DATASET_PATH)
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]
# Verification assertions
assert len(train_dataset) == 88, f"Expected 88 train examples, got {len(train_dataset)}"
assert len(eval_dataset) == 10, f"Expected 10 val examples, got {len(eval_dataset)}"
assert 'messages' in train_dataset.column_names, "Missing 'messages' column"
print(f"βœ… Dataset loaded successfully. Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
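# Optional sanity preview, assuming each 'messages' entry is a list of
# {"role": ..., "content": ...} dicts (the standard chat format): print the
# role sequence of the first example so a malformed schema fails loudly here.
print("🔎 First example roles:", [turn["role"] for turn in train_dataset[0]["messages"]])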
# --- 2. Model and Tokenizer Configuration ---
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
# 4-bit QLoRA quantization: NF4 weights with double quantization, bfloat16 compute
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
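# Rough memory math: 0.5B parameters at 4 bits is ~0.25 GB of weights, and
# double quantization saves a further ~0.4 bits/parameter by quantizing the
# quantization constants themselves; matmuls still run in bfloat16.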
print(f"πŸ”₯ Loading model '{MODEL_ID}' with 4-bit QLoRA...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
quantization_config=quantization_config,
device_map="auto" # Let accelerate handle device mapping
)
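# No manual PEFT wiring is needed here: when peft_config is passed to
# SFTTrainer below, TRL runs prepare_model_for_kbit_training and
# get_peft_model on this quantized base model itself.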
model.config.use_cache = False  # No KV cache during training; it only helps generation
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)  # Qwen2.5 is natively supported; no remote code needed
tokenizer.pad_token = tokenizer.eos_token  # Pad with EOS, a common SFT default
tokenizer.padding_side = "right"  # Right padding for training; left is only needed for generation
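# Caveat of padding with the EOS id: a collator that masks labels by token id
# would also mask genuine end-of-text tokens; TRL's SFT collator masks by
# padding position instead, which sidesteps this.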
# --- 3. LoRA and Training Configuration ---
# LoRA config
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
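# These seven projections cover every linear layer in a Qwen2 decoder block
# (q/k/v/o attention projections plus the SwiGLU MLP), i.e. the "all linear
# layers" recipe from the QLoRA paper; lora_alpha / r = 2 scales the updates.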
# Training configuration. SFTConfig subclasses TrainingArguments and adds the
# SFT-specific fields (max_length, packing) that used to be trainer kwargs.
training_args = SFTConfig(
    # Hub settings
    output_dir="qwen2.5-0.5b-itemsety-qlora",
    push_to_hub=True,
    hub_model_id="OliverSlivka/qwen2.5-0.5b-itemsety-qlora-final",
    hub_strategy="all_checkpoints",
    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,  # Common for QLoRA
    optim="paged_adamw_8bit",  # Paged 8-bit AdamW keeps optimizer state small
    bf16=True,  # Match the bfloat16 compute dtype of the quantized base
    # Logging & checkpointing
    logging_steps=5,
    save_strategy="steps",
    save_steps=5,  # ~18 optimizer steps total, so a 20-step interval never fires
    save_total_limit=2,
    # Evaluation
    eval_strategy="steps",
    eval_steps=5,  # Likewise, eval_steps=20 would skip evaluation entirely
    # Optimization
    warmup_ratio=0.03,
    lr_scheduler_type="constant_with_warmup",  # Plain "constant" ignores warmup_ratio
    max_grad_norm=0.3,
    max_steps=-1,  # Train for num_train_epochs
    # SFT-specific (named max_seq_length in older TRL releases)
    max_length=2048,
    packing=False,  # Do not pack sequences
    # Trackio reporting
    report_to="trackio",
    run_name="qwen-itemsety-qlora-run-final",
)
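# Step math behind the intervals above: ceil(88 / 4) = 22 micro-batches per
# epoch, ceil(22 / 4) = 6 optimizer steps per epoch, so roughly 18 steps in
# total across 3 epochs.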
# --- 4. Initialize Trainer ---
# The 'messages' column is recognized by TRL as conversational data and is run
# through the model's chat template automatically; dataset_text_field is only
# for plain-text columns, so it is deliberately not set.
print("🎯 Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,  # Pass the configured tokenizer explicitly
    peft_config=peft_config,  # SFTTrainer applies k-bit prep + LoRA wrapping itself
)
# --- 5. Start Training ---
print("πŸš€ Starting training...")
trainer.train()
print("βœ… Training complete!")
# Explicitly push the final adapter in case training ended after the last
# periodic checkpoint was uploaded
print("⬆️ Pushing final adapter to the Hub...")
trainer.push_to_hub()
print(f"💾 Model pushed to the Hub at: https://huggingface.co/{training_args.hub_model_id}")
print("βœ… All done.")