"""
Train Llama-3.1-8B-Instruct on allenai/tulu-3-sft-mixture (940K examples).

Hyperparameters follow the Tulu 3 SFT recipe (Allen AI), a proven recipe for Llama-3.1-8B:
- LR: 5e-6 (low, for stability on the 940K-example dataset)
- Effective batch size: 128 (large batch for a large dataset)
- Epochs: 2
- Max sequence length: 4096
- LR schedule: linear with warmup ratio 0.03
- LoRA: r=256, alpha=16, all-linear (per "LoRA Without Regret"; Tulu 3 itself used full fine-tuning)

Dataset: allenai/tulu-3-sft-mixture
- 940K examples from 19 curated sources
- Covers: math, code, instruction following, safety, science, chat
- Native chat `messages` format - no preprocessing needed

Usage:
    python train_tulu3.py
    # Or with CLI args:
    python train_tulu3.py --max_steps 100  # quick smoke test
"""
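# After training, the LoRA adapter can be loaded back for inference. A minimal
# sketch (assumes the default hub_model_id below and that the tokenizer was
# pushed with the adapter; swap in your own repo id):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     from peft import PeftModel
#
#     base = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-3.1-8B-Instruct", torch_dtype="bfloat16"
#     )
#     model = PeftModel.from_pretrained(base, "shaikhsalman/llama-3.1-8b-tulu3-lora")
#     tokenizer = AutoTokenizer.from_pretrained("shaikhsalman/llama-3.1-8b-tulu3-lora")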
|
|
import argparse

import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

import trackio
|
|
|
|
def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-tulu3-lora"):
    # Record the run configuration with trackio before training starts.
    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-tulu3",
        config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "dataset": "allenai/tulu-3-sft-mixture",
            "dataset_size": "940K",
            "lora_r": 256,
            "lora_alpha": 16,
            "target_modules": "all-linear",
            "learning_rate": 5e-6,
            "effective_batch": 128,
            "max_seq_length": 4096,
        },
    )

    # The Tulu 3 SFT mixture already ships in chat `messages` format, so it can
    # be passed to SFTTrainer without any column mapping.
    print("Loading allenai/tulu-3-sft-mixture (940K examples)...")
    dataset = load_dataset("allenai/tulu-3-sft-mixture", split="train")
    print(f"Loaded {len(dataset)} examples")
    print(f"Sources: {set(dataset['source'])}")
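    # Rough shape of one example, based on the dataset card (field names beyond
    # "messages" and "source" are an assumption and may differ slightly):
    #
    #     {"id": "...",
    #      "source": "...",
    #      "messages": [{"role": "user", "content": "..."},
    #                   {"role": "assistant", "content": "..."}]}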
    # LoRA over all linear layers. With standard LoRA the update is scaled by
    # lora_alpha / r = 16 / 256.
    peft_config = LoraConfig(
        r=256,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )
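    # Back-of-envelope check (an estimate, not a figure from the Tulu 3 paper):
    # r=256 adapters on every linear projection of an 8B Llama add on the order
    # of 0.6-0.7B trainable parameters. To confirm once the trainer has built
    # the model:
    #
    #     trainer.model.print_trainable_parameters()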
    training_args = SFTConfig(
        # Output / Hub
        output_dir="./output/llama3.1-8b-tulu3-lora",
        push_to_hub=push_hub,
        hub_model_id=hub_model_id,

        # Model loading (SFTTrainer instantiates the model from these kwargs)
        model_init_kwargs={
            "torch_dtype": torch.bfloat16,
            "attn_implementation": "flash_attention_2",
        },

        # Tulu 3 hyperparameters
        learning_rate=5e-6,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=32,  # 4 x 32 = effective batch of 128 per GPU
        num_train_epochs=2,
        lr_scheduler_type="linear",
        warmup_ratio=0.03,
        max_length=4096,

        # Packing and memory
        packing=True,  # pack short examples together up to max_length
        packing_strategy="bfd_split",
        gradient_checkpointing=True,  # trade extra compute for lower memory
        bf16=True,
        assistant_only_loss=True,  # compute loss on assistant tokens only
        eos_token="<|eot_id|>",  # Llama 3 end-of-turn token

        # Logging
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        report_to=["trackio"],
        disable_tqdm=True,

        # Checkpointing
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,

        # Optimizer
        optim="adamw_torch",
        max_grad_norm=1.0,
    )
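    # Rough step-count estimate under packing (an approximation, since packed
    # sequence count depends on token totals rather than example count):
    # optimizer steps per epoch ~= total training tokens / (max_length * effective batch)
    #                            = total_tokens / (4096 * 128)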
    # Optional hard cap on steps (overrides num_train_epochs); handy for smoke tests.
    if max_steps:
        training_args.max_steps = max_steps
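    # Passing a model id string lets SFTTrainer load the base model itself
    # (using model_init_kwargs above) and wrap it with the LoRA adapter from
    # peft_config.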
    trainer = SFTTrainer(
        model="meta-llama/Llama-3.1-8B-Instruct",
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
    )
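    # To resume an interrupted run from the latest checkpoint in output_dir,
    # one option (standard Trainer behaviour) is:
    #
    #     trainer.train(resume_from_checkpoint=True)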
    print("Starting training...")
    trainer.train()
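    # push_to_hub uploads the LoRA adapter, not merged weights. A sketch of
    # merging for standalone deployment, if needed (run after training):
    #
    #     merged = trainer.model.merge_and_unload()
    #     merged.save_pretrained("./output/llama3.1-8b-tulu3-merged")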
    if push_hub:
        trainer.push_to_hub()
        print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
|
|
    trackio.finish()
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_steps", type=int, default=None, help="Max steps (for quick test)")
    parser.add_argument("--hub_model_id", type=str, default="shaikhsalman/llama-3.1-8b-tulu3-lora")
    parser.add_argument("--no_push", action="store_true", help="Skip hub push")
    args = parser.parse_args()
|
|
    train(
        max_steps=args.max_steps,
        push_hub=not args.no_push,
        hub_model_id=args.hub_model_id,
    )
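# Note on scaling: "effective batch 128" assumes a single GPU (4 per device x 32
# accumulation steps). For multi-GPU runs, a sketch with accelerate (lower
# gradient_accumulation_steps so the global batch stays at 128):
#
#     accelerate launch train_tulu3.py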
|
|