feat: upgrade finetune.py — LoRA Without Regret + best datasets (tulu-3-sft, OpenThoughts-114k)
Browse files- ai-ml/hf-finetuning/finetune.py +104 -55
ai-ml/hf-finetuning/finetune.py
CHANGED
|
@@ -1,101 +1,146 @@
|
|
| 1 |
# =============================================================================
|
| 2 |
-
# HuggingFace Fine-Tuning Script —
|
| 3 |
# =============================================================================
|
| 4 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
# =============================================================================
|
| 6 |
|
| 7 |
import os
|
| 8 |
import torch
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
-
from typing import Optional
|
| 11 |
|
| 12 |
-
from datasets import load_dataset
|
| 13 |
from transformers import (
|
| 14 |
AutoModelForCausalLM,
|
| 15 |
AutoTokenizer,
|
| 16 |
BitsAndBytesConfig,
|
| 17 |
-
TrainingArguments,
|
| 18 |
)
|
| 19 |
-
from peft import LoraConfig
|
| 20 |
from trl import SFTTrainer, SFTConfig
|
| 21 |
import trackio
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
@dataclass
|
| 25 |
class FinetuneConfig:
|
| 26 |
-
"""Fine-tuning hyperparameters."""
|
| 27 |
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
|
| 28 |
-
|
| 29 |
output_dir: str = "/output/models"
|
| 30 |
-
hub_model_id: str = "devsecops/finetuned-llama"
|
| 31 |
|
| 32 |
-
# LoRA
|
| 33 |
-
lora_r: int =
|
| 34 |
-
lora_alpha: int =
|
| 35 |
lora_dropout: float = 0.05
|
|
|
|
| 36 |
|
| 37 |
-
# Training
|
| 38 |
-
num_train_epochs: int =
|
| 39 |
-
per_device_train_batch_size: int =
|
| 40 |
-
gradient_accumulation_steps: int = 8 # effective batch = 32
|
| 41 |
-
learning_rate: float = 2e-4
|
| 42 |
max_seq_length: int = 2048
|
| 43 |
warmup_ratio: float = 0.1
|
|
|
|
| 44 |
|
| 45 |
# Optimization
|
| 46 |
bf16: bool = True
|
| 47 |
gradient_checkpointing: bool = True
|
| 48 |
optim: str = "adamw_torch"
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def finetune(config: FinetuneConfig):
|
| 52 |
-
"""Fine-tune a model with LoRA + SFT."""
|
| 53 |
|
| 54 |
# --- Trackio monitoring ---
|
| 55 |
trackio.init(
|
| 56 |
project="devsecops-ml",
|
| 57 |
-
name=f"sft-{config.model_name.split('/')[-1]}",
|
| 58 |
config=vars(config),
|
| 59 |
)
|
| 60 |
|
| 61 |
-
# ---
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
bnb_4bit_quant_type="nf4",
|
| 65 |
-
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 66 |
-
bnb_4bit_use_double_quant=True,
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
# --- Load model ---
|
| 70 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 71 |
-
config.model_name,
|
| 72 |
-
trust_remote_code=True,
|
| 73 |
-
padding_side="right",
|
| 74 |
-
)
|
| 75 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 76 |
-
|
| 77 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 78 |
-
config.model_name,
|
| 79 |
-
quantization_config=bnb_config,
|
| 80 |
-
device_map="auto",
|
| 81 |
-
trust_remote_code=True,
|
| 82 |
-
torch_dtype=torch.bfloat16,
|
| 83 |
-
)
|
| 84 |
-
model = prepare_model_for_kbit_training(model)
|
| 85 |
|
| 86 |
-
# --- LoRA ---
|
| 87 |
-
|
| 88 |
r=config.lora_r,
|
| 89 |
lora_alpha=config.lora_alpha,
|
| 90 |
lora_dropout=config.lora_dropout,
|
| 91 |
bias="none",
|
| 92 |
task_type="CAUSAL_LM",
|
| 93 |
-
target_modules=
|
| 94 |
)
|
| 95 |
-
model = get_peft_model(model, lora_config)
|
| 96 |
-
|
| 97 |
-
# --- Dataset ---
|
| 98 |
-
dataset = load_dataset(config.dataset_name, split="train_sft[:5000]")
|
| 99 |
|
| 100 |
# --- SFT Config ---
|
| 101 |
sft_config = SFTConfig(
|
|
@@ -106,9 +151,14 @@ def finetune(config: FinetuneConfig):
|
|
| 106 |
learning_rate=config.learning_rate,
|
| 107 |
max_seq_length=config.max_seq_length,
|
| 108 |
warmup_ratio=config.warmup_ratio,
|
|
|
|
| 109 |
bf16=config.bf16,
|
| 110 |
gradient_checkpointing=config.gradient_checkpointing,
|
| 111 |
optim=config.optim,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
logging_strategy="steps",
|
| 113 |
logging_steps=10,
|
| 114 |
logging_first_step=True,
|
|
@@ -121,12 +171,12 @@ def finetune(config: FinetuneConfig):
|
|
| 121 |
disable_tqdm=True,
|
| 122 |
)
|
| 123 |
|
| 124 |
-
# --- Trainer ---
|
| 125 |
trainer = SFTTrainer(
|
| 126 |
-
model=
|
| 127 |
-
args=sft_config,
|
| 128 |
train_dataset=dataset,
|
| 129 |
-
|
|
|
|
| 130 |
)
|
| 131 |
|
| 132 |
# --- Train ---
|
|
@@ -135,7 +185,6 @@ def finetune(config: FinetuneConfig):
|
|
| 135 |
# --- Save ---
|
| 136 |
trainer.push_to_hub()
|
| 137 |
trackio.finish()
|
| 138 |
-
|
| 139 |
print(f"Model pushed to: https://huggingface.co/{config.hub_model_id}")
|
| 140 |
|
| 141 |
|
|
|
|
| 1 |
# =============================================================================
|
| 2 |
+
# HuggingFace Fine-Tuning Script — vNext Production Training
|
| 3 |
# =============================================================================
|
| 4 |
+
# Based on: "LoRA Without Regret" (Schulman et al., 2025)
|
| 5 |
+
# - LoRA matches full fine-tuning with correct configuration
|
| 6 |
+
# - Key: all-linear targets + r=256 + LR 2e-4 + batch < 32
|
| 7 |
+
#
|
| 8 |
+
# Datasets (ranked by quality):
|
| 9 |
+
# PRIMARY: allenai/tulu-3-sft-mixture (940K examples, 19 sources)
|
| 10 |
+
# REASONING: open-thoughts/OpenThoughts-114k (CoT traces)
|
| 11 |
+
# FALLBACK: HuggingFaceH4/ultrachat_200k (200K multi-turn chat)
|
| 12 |
# =============================================================================
|
| 13 |
|
| 14 |
import os
|
| 15 |
import torch
|
| 16 |
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Optional, List
|
| 18 |
|
| 19 |
+
from datasets import load_dataset, concatenate_datasets
|
| 20 |
from transformers import (
|
| 21 |
AutoModelForCausalLM,
|
| 22 |
AutoTokenizer,
|
| 23 |
BitsAndBytesConfig,
|
|
|
|
| 24 |
)
|
| 25 |
+
from peft import LoraConfig
|
| 26 |
from trl import SFTTrainer, SFTConfig
|
| 27 |
import trackio
|
| 28 |
|
| 29 |
|
| 30 |
+
# ---------- Dataset Registry ----------
|
| 31 |
+
DATASET_REGISTRY = {
|
| 32 |
+
"tulu-3-sft": {
|
| 33 |
+
"name": "allenai/tulu-3-sft-mixture",
|
| 34 |
+
"split": "train",
|
| 35 |
+
"format": "messages", # Already conversational
|
| 36 |
+
"size": "~940K",
|
| 37 |
+
"quality": "BEST — 19 curated sources (math, code, IF, safety, science)",
|
| 38 |
+
},
|
| 39 |
+
"openthoughts-114k": {
|
| 40 |
+
"name": "open-thoughts/OpenThoughts-114k",
|
| 41 |
+
"split": "train",
|
| 42 |
+
"format": "conversations", # Needs conversion
|
| 43 |
+
"size": "~114K",
|
| 44 |
+
"quality": "EXCELLENT — reasoning CoT traces",
|
| 45 |
+
},
|
| 46 |
+
"ultrachat-200k": {
|
| 47 |
+
"name": "HuggingFaceH4/ultrachat_200k",
|
| 48 |
+
"split": "train_sft",
|
| 49 |
+
"format": "messages",
|
| 50 |
+
"size": "~200K",
|
| 51 |
+
"quality": "GOOD — multi-turn chat (baseline fallback)",
|
| 52 |
+
},
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def convert_openthoughts_to_messages(example):
|
| 57 |
+
"""Convert OpenThoughts conversations format to standard messages."""
|
| 58 |
+
messages = []
|
| 59 |
+
if example.get("system"):
|
| 60 |
+
messages.append({"role": "system", "content": example["system"]})
|
| 61 |
+
for turn in example["conversations"]:
|
| 62 |
+
role = "user" if turn["from"] == "user" else "assistant"
|
| 63 |
+
messages.append({"role": role, "content": turn["value"]})
|
| 64 |
+
return {"messages": messages}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_and_prepare_dataset(dataset_key: str, max_samples: Optional[int] = None):
|
| 68 |
+
"""Load and format a dataset from the registry."""
|
| 69 |
+
info = DATASET_REGISTRY[dataset_key]
|
| 70 |
+
ds = load_dataset(info["name"], split=info["split"])
|
| 71 |
+
|
| 72 |
+
if max_samples:
|
| 73 |
+
ds = ds.select(range(min(max_samples, len(ds))))
|
| 74 |
+
|
| 75 |
+
if dataset_key == "openthoughts-114k":
|
| 76 |
+
remove_cols = [c for c in ds.column_names if c != "messages"]
|
| 77 |
+
ds = ds.map(
|
| 78 |
+
convert_openthoughts_to_messages,
|
| 79 |
+
remove_columns=remove_cols,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
return ds
|
| 83 |
+
|
| 84 |
+
|
| 85 |
@dataclass
|
| 86 |
class FinetuneConfig:
|
| 87 |
+
"""Fine-tuning hyperparameters — vNext (LoRA Without Regret config)."""
|
| 88 |
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
|
| 89 |
+
dataset_key: str = "tulu-3-sft" # Options: tulu-3-sft, openthoughts-114k, ultrachat-200k
|
| 90 |
output_dir: str = "/output/models"
|
| 91 |
+
hub_model_id: str = "devsecops/finetuned-llama-v2"
|
| 92 |
|
| 93 |
+
# LoRA (LoRA Without Regret optimal config)
|
| 94 |
+
lora_r: int = 256 # r=256 — sufficient capacity for SFT-scale datasets
|
| 95 |
+
lora_alpha: int = 16 # alpha=16 — stable scaling
|
| 96 |
lora_dropout: float = 0.05
|
| 97 |
+
target_modules: str = "all-linear" # ALL linear layers, not just attention
|
| 98 |
|
| 99 |
+
# Training (LoRA Without Regret: batch < 32, LR = 2e-4)
|
| 100 |
+
num_train_epochs: int = 1 # 1 epoch sufficient for 940K dataset
|
| 101 |
+
per_device_train_batch_size: int = 2
|
| 102 |
+
gradient_accumulation_steps: int = 8 # effective batch = 16 (< 32!)
|
| 103 |
+
learning_rate: float = 2e-4 # 10x full FT rate
|
| 104 |
max_seq_length: int = 2048
|
| 105 |
warmup_ratio: float = 0.1
|
| 106 |
+
lr_scheduler_type: str = "cosine"
|
| 107 |
|
| 108 |
# Optimization
|
| 109 |
bf16: bool = True
|
| 110 |
gradient_checkpointing: bool = True
|
| 111 |
optim: str = "adamw_torch"
|
| 112 |
|
| 113 |
+
# Packing (LoRA Without Regret recommends packing=True)
|
| 114 |
+
packing: bool = True
|
| 115 |
+
packing_strategy: str = "bfd_split" # Preserves all tokens
|
| 116 |
+
|
| 117 |
+
# Loss
|
| 118 |
+
assistant_only_loss: bool = True # Only compute loss on assistant tokens
|
| 119 |
+
|
| 120 |
|
| 121 |
def finetune(config: FinetuneConfig):
|
| 122 |
+
"""Fine-tune a model with LoRA + SFT (vNext — LoRA Without Regret config)."""
|
| 123 |
|
| 124 |
# --- Trackio monitoring ---
|
| 125 |
trackio.init(
|
| 126 |
project="devsecops-ml",
|
| 127 |
+
name=f"sft-{config.model_name.split('/')[-1]}-{config.dataset_key}",
|
| 128 |
config=vars(config),
|
| 129 |
)
|
| 130 |
|
| 131 |
+
# --- Dataset (best available) ---
|
| 132 |
+
dataset = load_and_prepare_dataset(config.dataset_key)
|
| 133 |
+
print(f"Dataset: {config.dataset_key} ({len(dataset)} examples)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
# --- LoRA (LoRA Without Regret: all-linear, r=256) ---
|
| 136 |
+
peft_config = LoraConfig(
|
| 137 |
r=config.lora_r,
|
| 138 |
lora_alpha=config.lora_alpha,
|
| 139 |
lora_dropout=config.lora_dropout,
|
| 140 |
bias="none",
|
| 141 |
task_type="CAUSAL_LM",
|
| 142 |
+
target_modules=config.target_modules,
|
| 143 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
# --- SFT Config ---
|
| 146 |
sft_config = SFTConfig(
|
|
|
|
| 151 |
learning_rate=config.learning_rate,
|
| 152 |
max_seq_length=config.max_seq_length,
|
| 153 |
warmup_ratio=config.warmup_ratio,
|
| 154 |
+
lr_scheduler_type=config.lr_scheduler_type,
|
| 155 |
bf16=config.bf16,
|
| 156 |
gradient_checkpointing=config.gradient_checkpointing,
|
| 157 |
optim=config.optim,
|
| 158 |
+
packing=config.packing,
|
| 159 |
+
packing_strategy=config.packing_strategy,
|
| 160 |
+
assistant_only_loss=config.assistant_only_loss,
|
| 161 |
+
eos_token="<|eot_id|>",
|
| 162 |
logging_strategy="steps",
|
| 163 |
logging_steps=10,
|
| 164 |
logging_first_step=True,
|
|
|
|
| 171 |
disable_tqdm=True,
|
| 172 |
)
|
| 173 |
|
| 174 |
+
# --- Trainer (SFTTrainer handles model loading + PEFT) ---
|
| 175 |
trainer = SFTTrainer(
|
| 176 |
+
model=config.model_name,
|
|
|
|
| 177 |
train_dataset=dataset,
|
| 178 |
+
peft_config=peft_config,
|
| 179 |
+
args=sft_config,
|
| 180 |
)
|
| 181 |
|
| 182 |
# --- Train ---
|
|
|
|
| 185 |
# --- Save ---
|
| 186 |
trainer.push_to_hub()
|
| 187 |
trackio.finish()
|
|
|
|
| 188 |
print(f"Model pushed to: https://huggingface.co/{config.hub_model_id}")
|
| 189 |
|
| 190 |
|