# devsecops-platform/model/train_tulu3.py
"""
Train Llama-3.1-8B-Instruct on allenai/tulu-3-sft-mixture (940K examples).
Recipe from Tulu 3 (Allen AI) - the SFT hyperparameters they report for Llama-3.1-8B:
- LR: 5e-6 (low for stability on 940K dataset)
- Effective batch: 128 (large batch for large dataset)
- Epochs: 2
- Max seq length: 4096
- LR schedule: linear with 0.03 warmup
- LoRA: r=256, alpha=16, all-linear (LoRA Without Regret)
Dataset: allenai/tulu-3-sft-mixture
- 940K examples from 19 curated sources
- Covers: math, code, instruction following, safety, science, chat
- Native messages format - zero preprocessing
Usage:
python train_tulu3.py
# Or with CLI args:
python train_tulu3.py --max_steps 100 # quick test
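    python train_tulu3.py --no_push                   # train without pushing to the Hub
    python train_tulu3.py --hub_model_id user/repo    # push to a different Hub repo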
"""
import argparse
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio


def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-tulu3-lora"):
# Trackio monitoring
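    # (trackio mirrors the wandb-style init/log/finish API; runs are stored
    # locally and can optionally be synced to a Hugging Face Space.)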
trackio.init(
project="devsecops-ml",
name="sft-llama3.1-8b-tulu3",
config={
"model": "meta-llama/Llama-3.1-8B-Instruct",
"dataset": "allenai/tulu-3-sft-mixture",
"dataset_size": "940K",
"lora_r": 256,
"lora_alpha": 16,
"target_modules": "all-linear",
"learning_rate": 5e-6,
"effective_batch": 128,
"max_seq_length": 4096,
},
)
# Load dataset - already in messages format, zero prep needed
print("Loading allenai/tulu-3-sft-mixture (940K examples)...")
dataset = load_dataset("allenai/tulu-3-sft-mixture", split="train")
print(f"Loaded {len(dataset)} examples")
print(f"Sources: {set(dataset["source"])}")
# LoRA config (LoRA Without Regret: r=256, all-linear)
peft_config = LoraConfig(
r=256,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear",
)
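    # Optional sanity check (sketch, not run here): applying this config with
    # peft.get_peft_model to a separately loaded base model and calling
    # print_trainable_parameters() reports the adapter's parameter count.
    # SFTTrainer applies the config itself when peft_config is passed below.
    #   peft_model = get_peft_model(base_model, peft_config)  # base_model: hypothetical, loaded separately
    #   peft_model.print_trainable_parameters()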
# Training config (Tulu 3 proven recipe)
training_args = SFTConfig(
# Output
output_dir="./output/llama3.1-8b-tulu3-lora",
push_to_hub=push_hub,
hub_model_id=hub_model_id,
# Model loading
model_init_kwargs={
"torch_dtype": torch.bfloat16,
"attn_implementation": "flash_attention_2",
},
# Tulu 3 recipe: LR 5e-6, batch 128, linear schedule
learning_rate=5e-6,
per_device_train_batch_size=4,
gradient_accumulation_steps=32, # 4 * 32 = 128 effective batch
num_train_epochs=2,
lr_scheduler_type="linear",
warmup_ratio=0.03,
max_length=4096,
# LoRA Without Regret optimizations
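        # Packing concatenates multiple short examples into each 4096-token
        # sequence to reduce padding waste ("bfd" = best-fit-decreasing binning).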
packing=True,
packing_strategy="bfd_split",
gradient_checkpointing=True,
bf16=True,
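        # Compute loss only on assistant turns; <|eot_id|> is Llama 3's
        # end-of-turn token, so completions terminate cleanly at each reply.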
assistant_only_loss=True,
eos_token="<|eot_id|>",
# Logging
logging_strategy="steps",
logging_steps=25,
logging_first_step=True,
report_to=["trackio"],
disable_tqdm=True,
# Checkpointing
save_strategy="steps",
save_steps=500,
save_total_limit=3,
# Optimization
optim="adamw_torch",
max_grad_norm=1.0,
)
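    # Note: 4 * 32 = 128 is the effective batch on a single GPU; when launching
    # with accelerate/torchrun on N GPUs it becomes 128 * N, so
    # gradient_accumulation_steps should be reduced to keep the recipe's 128.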
# Quick test override
if max_steps:
training_args.max_steps = max_steps
# Trainer
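    # Passing the model id as a string lets SFTTrainer load the checkpoint
    # itself, using model_init_kwargs from the SFTConfig above (bf16 +
    # FlashAttention-2) and wrapping it with the LoRA adapter from peft_config.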
trainer = SFTTrainer(
model="meta-llama/Llama-3.1-8B-Instruct",
train_dataset=dataset,
peft_config=peft_config,
args=training_args,
)
# Train
print("Starting training...")
trainer.train()
# Push to Hub
if push_hub:
trainer.push_to_hub()
print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
trackio.finish()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max_steps", type=int, default=None, help="Max steps (for quick test)")
parser.add_argument("--hub_model_id", type=str, default="shaikhsalman/llama-3.1-8b-tulu3-lora")
parser.add_argument("--no_push", action="store_true", help="Skip hub push")
args = parser.parse_args()
train(
max_steps=args.max_steps,
push_hub=not args.no_push,
hub_model_id=args.hub_model_id,
)