"""
Train Llama-3.1-8B-Instruct on allenai/tulu-3-sft-mixture (940K examples).

Recipe from Tulu 3 (Allen AI) - proven SOTA on Llama-3.1-8B:
- LR: 5e-6 (low for stability on 940K dataset)
- Effective batch: 128 (large batch for large dataset)
- Epochs: 2
- Max seq length: 4096
- LR schedule: linear with 0.03 warmup
- LoRA: r=256, alpha=16, all-linear (LoRA Without Regret)

NOTE(review): Tulu 3's 5e-6 LR was tuned for *full* fine-tuning. The
"LoRA Without Regret" write-up reports that LoRA's optimal LR is roughly
10x the full fine-tuning LR and uses a fixed alpha of 32 -- confirm these
LoRA hyperparameters before launching a full 2-epoch run.

Dataset: allenai/tulu-3-sft-mixture
- 940K examples from 19 curated sources
- Covers: math, code, IF, safety, science, chat
- Native messages format - zero preprocessing

Usage:
    python train_tulu3.py
    # Or with CLI args:
    python train_tulu3.py --max_steps 100  # quick test
"""

import argparse
import os

import torch
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "allenai/tulu-3-sft-mixture"
DEFAULT_HUB_MODEL_ID = "shaikhsalman/llama-3.1-8b-tulu3-lora"
OUTPUT_DIR = "./output/llama3.1-8b-tulu3-lora"

TRACKIO_PROJECT = "devsecops-ml"
RUN_NAME = "sft-llama3.1-8b-tulu3"

# Tulu 3 recipe -- single source of truth for both the trainer config and
# the hyperparameters recorded in the Trackio run config.
LEARNING_RATE = 5e-6
EFFECTIVE_BATCH = 128
PER_DEVICE_BATCH = 4
NUM_EPOCHS = 2
MAX_SEQ_LENGTH = 4096

# LoRA Without Regret: high rank, adapters on every linear layer.
LORA_R = 256
LORA_ALPHA = 16
LORA_TARGET_MODULES = "all-linear"


def _gradient_accumulation_steps(effective_batch, per_device_batch, world_size):
    """Return accumulation steps so per_device * world_size * steps ~= effective_batch.

    Floors the division (never below 1), so the realized effective batch can
    be slightly smaller than requested when it does not divide evenly.
    """
    return max(1, effective_batch // (per_device_batch * world_size))


def _build_peft_config():
    """Return the LoRA adapter config (LoRA Without Regret: r=256, all-linear)."""
    return LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=LORA_TARGET_MODULES,
    )


def _build_training_args(max_steps, push_hub, hub_model_id, gradient_accumulation_steps):
    """Return the SFTConfig implementing the Tulu 3 recipe.

    Args:
        max_steps: Optimizer-step cap, or None to train for NUM_EPOCHS.
        push_hub: Whether the Trainer pushes checkpoints to the Hub.
        hub_model_id: Target Hub repository ID.
        gradient_accumulation_steps: Steps per optimizer update.
    """
    return SFTConfig(
        # Output
        output_dir=OUTPUT_DIR,
        push_to_hub=push_hub,
        hub_model_id=hub_model_id,
        # Same name as the manual Trackio run so the Trainer's logging
        # callback can target it rather than a name derived from output_dir.
        run_name=RUN_NAME,
        # Model loading
        model_init_kwargs={
            "torch_dtype": torch.bfloat16,
            "attn_implementation": "flash_attention_2",
        },
        # Tulu 3 recipe: LR 5e-6, batch 128, linear schedule
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=NUM_EPOCHS,
        # -1 is the TrainingArguments default ("use num_train_epochs"); a
        # positive value overrides the epoch count for quick tests.
        max_steps=-1 if max_steps is None else max_steps,
        lr_scheduler_type="linear",
        warmup_ratio=0.03,
        max_length=MAX_SEQ_LENGTH,
        # Throughput / memory
        packing=True,
        packing_strategy="bfd_split",
        gradient_checkpointing=True,
        bf16=True,
        # NOTE(review): TRL builds the assistant-token mask from the chat
        # template's {% generation %} markers; confirm the Llama-3.1 template
        # provides them, otherwise TRL is expected to reject the dataset.
        assistant_only_loss=True,
        eos_token="<|eot_id|>",
        # Logging
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        report_to=["trackio"],
        disable_tqdm=True,
        # Checkpointing
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        # Optimization
        optim="adamw_torch",
        max_grad_norm=1.0,
    )


def train(max_steps=None, push_hub=True, hub_model_id=DEFAULT_HUB_MODEL_ID):
    """Fine-tune MODEL_ID on DATASET_ID with LoRA and optionally push to the Hub.

    Args:
        max_steps: If not None, cap training at this many optimizer steps
            (overrides the epoch count); intended for quick smoke tests.
        push_hub: Push checkpoints and the final adapter to ``hub_model_id``.
        hub_model_id: Target Hub repository ID.
    """
    # WORLD_SIZE is set by torchrun / `accelerate launch`; assume a single
    # process otherwise. Keeps the effective batch at the recipe's 128
    # regardless of GPU count (32 accumulation steps on one process).
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    grad_accum = _gradient_accumulation_steps(EFFECTIVE_BATCH, PER_DEVICE_BATCH, world_size)
    effective_batch = PER_DEVICE_BATCH * grad_accum * world_size

    # Point the Trainer's Trackio callback (report_to=["trackio"]) at the same
    # project as the manual run below. NOTE(review): older transformers read
    # the project from this env var; newer releases may take it from the
    # training args instead -- verify metrics land in this project.
    os.environ.setdefault("TRACKIO_PROJECT", TRACKIO_PROJECT)

    # Trackio monitoring
    trackio.init(
        project=TRACKIO_PROJECT,
        name=RUN_NAME,
        config={
            "model": MODEL_ID,
            "dataset": DATASET_ID,
            "dataset_size": "940K",
            "lora_r": LORA_R,
            "lora_alpha": LORA_ALPHA,
            "target_modules": LORA_TARGET_MODULES,
            "learning_rate": LEARNING_RATE,
            "effective_batch": effective_batch,
            "max_seq_length": MAX_SEQ_LENGTH,
        },
    )

    # Finish the Trackio run even if loading or training raises.
    try:
        # Load dataset - already in messages format, zero prep needed
        print(f"Loading {DATASET_ID} (940K examples)...")
        dataset = load_dataset(DATASET_ID, split="train")
        print(f"Loaded {len(dataset)} examples")
        # Dataset.unique reads the Arrow column instead of building a
        # 940K-element Python list; single quotes keep this f-string valid
        # on Python < 3.12.
        print(f"Sources: {sorted(dataset.unique('source'))}")
        print(
            f"Effective batch: {PER_DEVICE_BATCH} x {grad_accum} x {world_size}"
            f" = {effective_batch}"
        )

        trainer = SFTTrainer(
            model=MODEL_ID,
            train_dataset=dataset,
            peft_config=_build_peft_config(),
            args=_build_training_args(max_steps, push_hub, hub_model_id, grad_accum),
        )

        print("Starting training...")
        trainer.train()

        if push_hub:
            trainer.push_to_hub()
            print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
    finally:
        trackio.finish()


def main():
    """Parse CLI arguments and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_steps", type=int, default=None, help="Max steps (for quick test)")
    parser.add_argument("--hub_model_id", type=str, default=DEFAULT_HUB_MODEL_ID)
    parser.add_argument("--no_push", action="store_true", help="Skip hub push")
    args = parser.parse_args()

    train(
        max_steps=args.max_steps,
        push_hub=not args.no_push,
        hub_model_id=args.hub_model_id,
    )


if __name__ == "__main__":
    main()