File size: 4,263 Bytes
d678e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0840157
d678e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Train Llama-3.1-8B-Instruct on allenai/tulu-3-sft-mixture (940K examples).

Recipe from Tulu 3 (Allen AI) - proven SOTA on Llama-3.1-8B:
  - LR: 5e-6 (low for stability on 940K dataset)
  - Effective batch: 128 (large batch for large dataset)
  - Epochs: 2
  - Max seq length: 4096
  - LR schedule: linear with 0.03 warmup
  - LoRA: r=256, alpha=16, all-linear (LoRA Without Regret)

Dataset: allenai/tulu-3-sft-mixture
  - 940K examples from 19 curated sources
  - Covers: math, code, IF, safety, science, chat
  - Native messages format - zero preprocessing

Usage:
  python train_tulu3.py
  # Or with CLI args:
  python train_tulu3.py --max_steps 100  # quick test
"""

import argparse
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio


def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-tulu3-lora"):
    """Fine-tune Llama-3.1-8B-Instruct with LoRA on allenai/tulu-3-sft-mixture.

    Starts a Trackio run, loads the dataset, builds the LoRA and SFT
    configurations, trains, optionally pushes the result to the Hub, and
    closes the Trackio run on the way out (also when a step raises).

    Args:
        max_steps: Optional int cap on training steps, meant for quick test
            runs. ``None`` (or ``0``) leaves the run length to
            ``num_train_epochs``.
        push_hub: When true, enables ``push_to_hub`` in the training config
            and pushes once more after training completes.
        hub_model_id: Hub repository id used as the push target.
    """
    # Single source of truth for the values that are both logged to Trackio
    # and used to build the real configs, so the two cannot drift apart.
    model_name = "meta-llama/Llama-3.1-8B-Instruct"
    dataset_name = "allenai/tulu-3-sft-mixture"
    lora_r = 256
    lora_alpha = 16
    target_modules = "all-linear"
    learning_rate = 5e-6
    per_device_batch = 4
    grad_accum_steps = 32
    max_seq_length = 4096

    # Trackio monitoring
    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-tulu3",
        config={
            "model": model_name,
            "dataset": dataset_name,
            "dataset_size": "940K",
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "target_modules": target_modules,
            "learning_rate": learning_rate,
            # 4 * 32 = 128 per process; assumes a single-device launch
            # -- TODO confirm against the actual launch setup.
            "effective_batch": per_device_batch * grad_accum_steps,
            "max_seq_length": max_seq_length,
        },
    )

    try:
        # Load dataset - already in messages format, zero prep needed
        print("Loading allenai/tulu-3-sft-mixture (940K examples)...")
        dataset = load_dataset(dataset_name, split="train")
        print(f"Loaded {len(dataset)} examples")
        # Dataset.unique avoids materialising the whole column as a Python
        # list, and keeping the expression out of the f-string avoids nested
        # double quotes, which are a SyntaxError before Python 3.12.
        sources = set(dataset.unique("source"))
        print(f"Sources: {sources}")

        # LoRA config (LoRA Without Regret: r=256, all-linear)
        peft_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
        )

        # Training config (Tulu 3 proven recipe)
        training_args = SFTConfig(
            # Output
            output_dir="./output/llama3.1-8b-tulu3-lora",
            push_to_hub=push_hub,
            hub_model_id=hub_model_id,

            # Model loading
            model_init_kwargs={
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },

            # Tulu 3 recipe: LR 5e-6, batch 128, linear schedule
            learning_rate=learning_rate,
            per_device_train_batch_size=per_device_batch,
            gradient_accumulation_steps=grad_accum_steps,
            num_train_epochs=2,
            lr_scheduler_type="linear",
            warmup_ratio=0.03,
            max_length=max_seq_length,
            # Quick test override: passed at construction instead of being
            # patched onto the config afterwards. -1 means "no step cap".
            max_steps=max_steps if max_steps else -1,

            # LoRA Without Regret optimizations
            packing=True,
            packing_strategy="bfd_split",
            gradient_checkpointing=True,
            bf16=True,
            assistant_only_loss=True,
            eos_token="<|eot_id|>",

            # Logging
            logging_strategy="steps",
            logging_steps=25,
            logging_first_step=True,
            report_to=["trackio"],
            disable_tqdm=True,

            # Checkpointing
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,

            # Optimization
            optim="adamw_torch",
            max_grad_norm=1.0,
        )

        # Trainer
        trainer = SFTTrainer(
            model=model_name,
            train_dataset=dataset,
            peft_config=peft_config,
            args=training_args,
        )

        # Train
        print("Starting training...")
        trainer.train()

        # Push to Hub
        if push_hub:
            trainer.push_to_hub()
            print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
    finally:
        # Close the Trackio run even if loading, training or pushing failed.
        trackio.finish()


if __name__ == "__main__":
    # Script entry point: read the command-line flags and forward them to
    # train(). --no_push is a negative flag, so it is inverted into push_hub.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Max steps (for quick test)",
    )
    cli.add_argument(
        "--hub_model_id",
        type=str,
        default="shaikhsalman/llama-3.1-8b-tulu3-lora",
    )
    cli.add_argument(
        "--no_push",
        action="store_true",
        help="Skip hub push",
    )
    opts = cli.parse_args()

    should_push = not opts.no_push
    train(
        hub_model_id=opts.hub_model_id,
        push_hub=should_push,
        max_steps=opts.max_steps,
    )