File size: 4,263 Bytes
"""
Train Llama-3.1-8B-Instruct on allenai/tulu-3-sft-mixture (940K examples).
Recipe from Tulu 3 (Allen AI) - proven SOTA on Llama-3.1-8B:
- LR: 5e-6 (low for stability on 940K dataset)
- Effective batch: 128 (large batch for large dataset)
- Epochs: 2
- Max seq length: 4096
- LR schedule: linear with 0.03 warmup
- LoRA: r=256, alpha=16, all-linear (LoRA Without Regret)
Dataset: allenai/tulu-3-sft-mixture
- 940K examples from 19 curated sources
- Covers: math, code, IF, safety, science, chat
- Native messages format - zero preprocessing
Usage:
python train_tulu3.py
# Or with CLI args:
python train_tulu3.py --max_steps 100 # quick test
"""
import argparse
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-tulu3-lora"):
    """Fine-tune Llama-3.1-8B-Instruct with LoRA on allenai/tulu-3-sft-mixture.

    Starts a Trackio run, loads the dataset, builds the LoRA and SFT
    configurations, trains with ``SFTTrainer``, optionally pushes the result
    to the Hugging Face Hub, and always closes the Trackio run on the way out.

    Args:
        max_steps: Optional cap on training steps, intended for quick tests.
            A falsy value (``None`` or ``0``) leaves the epoch-based schedule
            (``num_train_epochs=2``) in effect.
        push_hub: When true, enables ``push_to_hub`` in the training config
            and calls ``trainer.push_to_hub()`` after training.
        hub_model_id: Hub repository id passed to the training config and
            printed in the final URL.
    """
    # Trackio monitoring
    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-tulu3",
        config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "dataset": "allenai/tulu-3-sft-mixture",
            "dataset_size": "940K",
            "lora_r": 256,
            "lora_alpha": 16,
            "target_modules": "all-linear",
            "learning_rate": 5e-6,
            "effective_batch": 128,
            "max_seq_length": 4096,
        },
    )

    # Everything after init runs under try/finally so the Trackio run is
    # closed even when loading, training, or the Hub push raises.
    try:
        # Load dataset - already in messages format, zero prep needed
        print("Loading allenai/tulu-3-sft-mixture (940K examples)...")
        dataset = load_dataset("allenai/tulu-3-sft-mixture", split="train")
        print(f"Loaded {len(dataset)} examples")
        # Computed outside the f-string: reusing double quotes inside a
        # double-quoted f-string is a SyntaxError before Python 3.12.
        # Dataset.unique avoids pulling the whole column into a Python list.
        sources = set(dataset.unique("source"))
        print(f"Sources: {sources}")

        # LoRA config (LoRA Without Regret: r=256, all-linear)
        peft_config = LoraConfig(
            r=256,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules="all-linear",
        )

        # Quick test override: only forwarded when a positive step count was
        # requested, so the config default applies otherwise.
        step_override = {"max_steps": max_steps} if max_steps else {}

        # Training config (Tulu 3 proven recipe)
        training_args = SFTConfig(
            # Output
            output_dir="./output/llama3.1-8b-tulu3-lora",
            push_to_hub=push_hub,
            hub_model_id=hub_model_id,
            # Model loading
            model_init_kwargs={
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },
            # Tulu 3 recipe: LR 5e-6, batch 128, linear schedule
            learning_rate=5e-6,
            per_device_train_batch_size=4,
            # 4 * 32 = 128 effective batch.
            # NOTE(review): presumably a single-process launch; with several
            # processes the global batch would scale with their count - confirm.
            gradient_accumulation_steps=32,
            num_train_epochs=2,
            lr_scheduler_type="linear",
            warmup_ratio=0.03,
            max_length=4096,
            # LoRA Without Regret optimizations
            packing=True,
            packing_strategy="bfd_split",
            gradient_checkpointing=True,
            bf16=True,
            assistant_only_loss=True,
            eos_token="<|eot_id|>",
            # Logging
            logging_strategy="steps",
            logging_steps=25,
            logging_first_step=True,
            report_to=["trackio"],
            disable_tqdm=True,
            # Checkpointing
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,
            # Optimization
            optim="adamw_torch",
            max_grad_norm=1.0,
            **step_override,
        )

        # Trainer
        trainer = SFTTrainer(
            model="meta-llama/Llama-3.1-8B-Instruct",
            train_dataset=dataset,
            peft_config=peft_config,
            args=training_args,
        )

        # Train
        print("Starting training...")
        trainer.train()

        # Push to Hub
        if push_hub:
            trainer.push_to_hub()
            print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
    finally:
        trackio.finish()
if __name__ == "__main__":
    # Script entry point: read the optional command-line overrides and
    # forward them to train().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Max steps (for quick test)",
    )
    cli.add_argument(
        "--hub_model_id",
        type=str,
        default="shaikhsalman/llama-3.1-8b-tulu3-lora",
    )
    cli.add_argument("--no_push", action="store_true", help="Skip hub push")
    opts = cli.parse_args()

    # --no_push is a negative flag, so it is inverted for train()'s push_hub.
    train(
        max_steps=opts.max_steps,
        push_hub=not opts.no_push,
        hub_model_id=opts.hub_model_id,
    )
|