#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.44.2",
# "accelerate>=0.24.0",
# "bitsandbytes>=0.41.0",
# "datasets",
# "scipy",
# "hf_transfer",
# "rich",
# "trackio",
# ]
# ///
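# The block above is PEP 723 inline script metadata: a PEP-723-aware runner
# such as `uv run train.py` resolves and installs these dependencies before
# executing the script, so no separate requirements file is needed.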
import torch
import os
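# hf_transfer (declared in the dependencies above) only takes effect if this
# flag is set before huggingface_hub is first imported. Enabling it here is an
# assumption about intent; setdefault() preserves any value already exported.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")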
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# === CONFIGURATION ===
MODEL_NAME = "Qwen/Qwen3-8B"  # Base model - fits locally on an M2 Pro 16GB after fine-tuning
DATASET_NAME = "ceperaltab/diamond-vision-dataset"
OUTPUT_DIR = "diamond-vision-expert"
HF_USERNAME = "ceperaltab"
def main():
print("=" * 60)
print("Diamond Vision Expert β€” QLoRA Fine-tuning")
print(f"Base model : {MODEL_NAME}")
print(f"Dataset : {DATASET_NAME}")
print("=" * 60)
# Load dataset
print(f"πŸ“¦ Loading dataset: {DATASET_NAME}...")
dataset = load_dataset(DATASET_NAME, split="train")
print(f"βœ… Dataset loaded: {len(dataset)} examples")
# Train / eval split
print("πŸ”€ Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f" Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")
    # Training config
    config = SFTConfig(
        output_dir=OUTPUT_DIR,
        push_to_hub=True,
        hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
        hub_strategy="every_save",
        # Training
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        # NOTE: sequence length is left at TRL's default; recent TRL releases
        # renamed SFTConfig's max_seq_length to max_length, so neither is set here.
        # Logging & checkpointing
        logging_steps=10,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        # Evaluation
        eval_strategy="steps",
        eval_steps=500,
        # Optimization
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,
        bf16=True,  # A10G supports bf16
        # Monitoring
        report_to="trackio",
        project="diamond-vision-training",
        run_name="diamond-vision-qwen3-8b-v1",
    )
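    # Effective batch size: 1 sequence per device x 8 accumulation steps
    # = 8 sequences per optimizer step (per GPU).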
    # LoRA
    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
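    # The adapter's effective scale is lora_alpha / r = 16 / 64 = 0.25.
    # Targeting all attention and MLP projections is the standard QLoRA recipe:
    # only these low-rank adapters are trained while the quantized base stays frozen.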
    # 4-bit QLoRA quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
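    # NF4 stores weights in 4 bits (~0.5 bytes/param, so the 8B base weighs in
    # at roughly 4-5 GB of VRAM), while compute runs in bfloat16. Double
    # quantization also quantizes the quantization constants themselves,
    # shaving off a little more memory per the QLoRA paper.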
    # Load model
    print(f"🔄 Loading base model: {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
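    # Right padding is the conventional choice for causal-LM training (left
    # padding only matters at generation time), and reusing EOS as PAD is a
    # common fallback for checkpoints that ship without a dedicated pad token.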
    # Train
    print("🎯 Initializing trainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
    )

    print("🚀 Starting training...")
    trainer.train()

    print("💾 Pushing final adapter to Hub...")
    trainer.push_to_hub()
    trackio.finish()
    print("✅ Done! Adapter pushed to:", f"https://huggingface.co/{HF_USERNAME}/{OUTPUT_DIR}")

if __name__ == "__main__":
    main()
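
# A minimal sketch of trying the adapter afterwards (assumes the push above
# succeeded; run in a separate session with transformers + peft installed):
#
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", device_map="auto")
#   model = PeftModel.from_pretrained(base, "ceperaltab/diamond-vision-expert")
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")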