# devsecops-platform / model / finetune_configurable.py
# refactor: merged structure - model at center, DevSecOps wrapped around it
# (commit 9d4d5c7, verified - shaikhsalman)
# =============================================================================
# HuggingFace Fine-Tuning Script β€” vNext Production Training
# =============================================================================
# Based on: "LoRA Without Regret" (Schulman et al., 2025)
# - LoRA matches full fine-tuning with correct configuration
# - Key: all-linear targets + r=256 + LR 2e-4 + batch < 32
#
# Datasets (ranked by quality):
# PRIMARY: allenai/tulu-3-sft-mixture (940K examples, 19 sources)
# REASONING: open-thoughts/OpenThoughts-114k (CoT traces)
# FALLBACK: HuggingFaceH4/ultrachat_200k (200K multi-turn chat)
# =============================================================================
import os
import torch
from dataclasses import dataclass, field
from typing import Optional, List
from datasets import load_dataset, concatenate_datasets
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
# ---------- Dataset Registry ----------
# Maps a short key to the Hub dataset id plus the metadata the loader needs:
# which split to pull, what conversational schema the rows arrive in, and a
# human-readable size/quality note for operators choosing a dataset.
DATASET_REGISTRY = {}

# PRIMARY: highest-quality curated SFT mixture, already in `messages` form.
DATASET_REGISTRY["tulu-3-sft"] = {
    "name": "allenai/tulu-3-sft-mixture",
    "split": "train",
    "format": "messages",  # Already conversational
    "size": "~940K",
    "quality": "BEST β€” 19 curated sources (math, code, IF, safety, science)",
}

# REASONING: chain-of-thought traces; stored as `conversations`, so the
# loader converts it to `messages` before training.
DATASET_REGISTRY["openthoughts-114k"] = {
    "name": "open-thoughts/OpenThoughts-114k",
    "split": "train",
    "format": "conversations",  # Needs conversion
    "size": "~114K",
    "quality": "EXCELLENT β€” reasoning CoT traces",
}

# FALLBACK: solid multi-turn chat baseline when the others are unavailable.
DATASET_REGISTRY["ultrachat-200k"] = {
    "name": "HuggingFaceH4/ultrachat_200k",
    "split": "train_sft",
    "format": "messages",
    "size": "~200K",
    "quality": "GOOD β€” multi-turn chat (baseline fallback)",
}
def convert_openthoughts_to_messages(example):
    """Convert an OpenThoughts/ShareGPT-style row to the standard messages schema.

    Input rows carry a `conversations` list of ``{"from": ..., "value": ...}``
    turns and an optional top-level `system` prompt. ShareGPT-derived dumps
    label the human side either ``"user"`` or ``"human"`` (and the model side
    ``"assistant"`` or ``"gpt"``), so both spellings map to ``"user"``; any
    other speaker tag is treated as the assistant.

    Args:
        example: One dataset row with a `conversations` list and optionally
            a `system` string.

    Returns:
        ``{"messages": [...]}`` where each entry is
        ``{"role": "system"|"user"|"assistant", "content": str}``.
    """
    messages = []
    # Missing or empty system prompts are simply omitted.
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        # Accept both "user" and the ShareGPT spelling "human"; everything
        # else (e.g. "assistant", "gpt") is the model's turn.
        role = "user" if turn["from"] in ("user", "human") else "assistant"
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}
def load_and_prepare_dataset(dataset_key: str, max_samples: Optional[int] = None):
    """Load a registry dataset and normalize it to the `messages` schema.

    Args:
        dataset_key: Key into DATASET_REGISTRY (e.g. "tulu-3-sft").
        max_samples: If truthy, keep only the first ``max_samples`` rows
            (handy for smoke tests); ``None``/0 means the full split.

    Returns:
        A `datasets.Dataset` whose rows carry a chat-style `messages` column.

    Raises:
        ValueError: If `dataset_key` is not present in DATASET_REGISTRY.
    """
    # Fail fast with an actionable message instead of a bare KeyError.
    if dataset_key not in DATASET_REGISTRY:
        valid = ", ".join(sorted(DATASET_REGISTRY))
        raise ValueError(f"Unknown dataset_key {dataset_key!r}; expected one of: {valid}")
    info = DATASET_REGISTRY[dataset_key]
    ds = load_dataset(info["name"], split=info["split"])
    if max_samples:
        # min() makes requesting more rows than exist a harmless no-op.
        ds = ds.select(range(min(max_samples, len(ds))))
    # Drive conversion off the registry's declared format rather than a
    # hard-coded dataset key, so new "conversations"-format entries work too.
    if info["format"] == "conversations":
        # Drop every other column so only "messages" remains after mapping.
        remove_cols = [c for c in ds.column_names if c != "messages"]
        ds = ds.map(
            convert_openthoughts_to_messages,
            remove_columns=remove_cols,
        )
    return ds
@dataclass
class FinetuneConfig:
    """Fine-tuning hyperparameters - vNext ("LoRA Without Regret" config).

    Defaults follow that recipe: LoRA on ALL linear layers with r=256 and
    alpha=16, learning rate 2e-4 (~10x the full-FT rate), and an effective
    batch size below 32. `dataset_key` selects an entry in DATASET_REGISTRY.
    """
    # Base model and data selection; output_dir/hub_model_id control where
    # checkpoints land locally and on the Hub.
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    dataset_key: str = "tulu-3-sft"  # Options: tulu-3-sft, openthoughts-114k, ultrachat-200k
    output_dir: str = "/output/models"
    hub_model_id: str = "devsecops/finetuned-llama-v2"
    # LoRA (LoRA Without Regret optimal config)
    lora_r: int = 256  # r=256 β€” sufficient capacity for SFT-scale datasets
    lora_alpha: int = 16  # alpha=16 β€” stable scaling
    lora_dropout: float = 0.05
    target_modules: str = "all-linear"  # ALL linear layers, not just attention
    # Training (LoRA Without Regret: batch < 32, LR = 2e-4)
    num_train_epochs: int = 1  # 1 epoch sufficient for 940K dataset
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 8  # effective batch = 2 * 8 = 16 (< 32!)
    learning_rate: float = 2e-4  # 10x full FT rate
    max_seq_length: int = 2048
    warmup_ratio: float = 0.1
    lr_scheduler_type: str = "cosine"
    # Optimization: bf16 precision + gradient checkpointing to fit 8B on GPU.
    bf16: bool = True
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"
    # Packing (LoRA Without Regret recommends packing=True)
    packing: bool = True
    packing_strategy: str = "bfd_split"  # best-fit-decreasing; preserves all tokens
    # Loss masking
    assistant_only_loss: bool = True  # Only compute loss on assistant tokens
def finetune(config: FinetuneConfig):
    """Run LoRA SFT per the "LoRA Without Regret" recipe and push to the Hub.

    Logs the run to trackio, loads the configured dataset, trains the base
    model with an all-linear LoRA adapter via TRL's SFTTrainer, then uploads
    the result to `config.hub_model_id`.

    Args:
        config: Hyperparameter bundle; see FinetuneConfig.
    """
    # --- Experiment tracking: run name encodes model + dataset ---
    run_name = f"sft-{config.model_name.split('/')[-1]}-{config.dataset_key}"
    trackio.init(
        project="devsecops-ml",
        name=run_name,
        config=vars(config),
    )
    # --- Training data, normalized to the chat "messages" schema ---
    train_ds = load_and_prepare_dataset(config.dataset_key)
    print(f"Dataset: {config.dataset_key} ({len(train_ds)} examples)")
    # --- Adapter spec: every linear layer gets a LoRA pair (r=256) ---
    lora_cfg = LoraConfig(
        task_type="CAUSAL_LM",
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.target_modules,
        bias="none",
    )
    # --- Trainer arguments: schedule, precision, packing, checkpoints, Hub ---
    train_args = SFTConfig(
        output_dir=config.output_dir,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type=config.lr_scheduler_type,
        max_length=config.max_seq_length,
        bf16=config.bf16,
        gradient_checkpointing=config.gradient_checkpointing,
        optim=config.optim,
        packing=config.packing,
        packing_strategy=config.packing_strategy,
        assistant_only_loss=config.assistant_only_loss,
        eos_token="<|eot_id|>",  # Llama-3 end-of-turn token
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        push_to_hub=True,
        hub_model_id=config.hub_model_id,
        report_to="trackio",
        disable_tqdm=True,
    )
    # SFTTrainer accepts a model id string and handles loading + PEFT wrap.
    trainer = SFTTrainer(
        model=config.model_name,
        args=train_args,
        train_dataset=train_ds,
        peft_config=lora_cfg,
    )
    trainer.train()
    # Upload the final adapter, then close out monitoring.
    trainer.push_to_hub()
    trackio.finish()
    print(f"Model pushed to: https://huggingface.co/{config.hub_model_id}")
if __name__ == "__main__":
    # Train with the registry defaults (tulu-3-sft, Llama-3.1-8B-Instruct).
    finetune(FinetuneConfig())