shaikhsalman commited on
Commit
d678e13
·
verified ·
1 Parent(s): 82ebd41

Upload ai-ml/hf-finetuning/train_tulu3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai-ml/hf-finetuning/train_tulu3.py +148 -0
ai-ml/hf-finetuning/train_tulu3.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Llama-3.1-8B-Instruct on allenai/tulu-3-sft-mixture (940K examples).
3
+
4
+ Recipe from Tulu 3 (Allen AI) - proven SOTA on Llama-3.1-8B:
5
+ - LR: 5e-6 (low for stability on 940K dataset)
6
+ - Effective batch: 128 (large batch for large dataset)
7
+ - Epochs: 2
8
+ - Max seq length: 4096
9
+ - LR schedule: linear with 0.03 warmup
10
+ - LoRA: r=256, alpha=16, all-linear (LoRA Without Regret)
11
+
12
+ Dataset: allenai/tulu-3-sft-mixture
13
+ - 940K examples from 19 curated sources
14
+ - Covers: math, code, IF, safety, science, chat
15
+ - Native messages format - zero preprocessing
16
+
17
+ Usage:
18
+ python train_tulu3.py
19
+ # Or with CLI args:
20
+ python train_tulu3.py --max_steps 100 # quick test
21
+ """
22
+
23
+ import argparse
24
+ import torch
25
+ from datasets import load_dataset
26
+ from peft import LoraConfig
27
+ from trl import SFTTrainer, SFTConfig
28
+ import trackio
29
+
30
+
31
+ def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-tulu3-lora"):
32
+
33
+ # Trackio monitoring
34
+ trackio.init(
35
+ project="devsecops-ml",
36
+ name="sft-llama3.1-8b-tulu3",
37
+ config={
38
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
39
+ "dataset": "allenai/tulu-3-sft-mixture",
40
+ "dataset_size": "940K",
41
+ "lora_r": 256,
42
+ "lora_alpha": 16,
43
+ "target_modules": "all-linear",
44
+ "learning_rate": 5e-6,
45
+ "effective_batch": 128,
46
+ "max_seq_length": 4096,
47
+ },
48
+ )
49
+
50
+ # Load dataset - already in messages format, zero prep needed
51
+ print("Loading allenai/tulu-3-sft-mixture (940K examples)...")
52
+ dataset = load_dataset("allenai/tulu-3-sft-mixture", split="train")
53
+ print(f"Loaded {len(dataset)} examples")
54
+ print(f"Sources: {set(dataset["source"])}")
55
+
56
+ # LoRA config (LoRA Without Regret: r=256, all-linear)
57
+ peft_config = LoraConfig(
58
+ r=256,
59
+ lora_alpha=16,
60
+ lora_dropout=0.05,
61
+ bias="none",
62
+ task_type="CAUSAL_LM",
63
+ target_modules="all-linear",
64
+ )
65
+
66
+ # Training config (Tulu 3 proven recipe)
67
+ training_args = SFTConfig(
68
+ # Output
69
+ output_dir="./output/llama3.1-8b-tulu3-lora",
70
+ push_to_hub=push_hub,
71
+ hub_model_id=hub_model_id,
72
+
73
+ # Model loading
74
+ model_init_kwargs={
75
+ "torch_dtype": torch.bfloat16,
76
+ "attn_implementation": "flash_attention_2",
77
+ },
78
+
79
+ # Tulu 3 recipe: LR 5e-6, batch 128, linear schedule
80
+ learning_rate=5e-6,
81
+ per_device_train_batch_size=4,
82
+ gradient_accumulation_steps=32, # 4 * 32 = 128 effective batch
83
+ num_train_epochs=2,
84
+ lr_scheduler_type="linear",
85
+ warmup_ratio=0.03,
86
+ max_seq_length=4096,
87
+
88
+ # LoRA Without Regret optimizations
89
+ packing=True,
90
+ packing_strategy="bfd_split",
91
+ gradient_checkpointing=True,
92
+ bf16=True,
93
+ assistant_only_loss=True,
94
+ eos_token="<|eot_id|>",
95
+
96
+ # Logging
97
+ logging_strategy="steps",
98
+ logging_steps=25,
99
+ logging_first_step=True,
100
+ report_to=["trackio"],
101
+ disable_tqdm=True,
102
+
103
+ # Checkpointing
104
+ save_strategy="steps",
105
+ save_steps=500,
106
+ save_total_limit=3,
107
+
108
+ # Optimization
109
+ optim="adamw_torch",
110
+ max_grad_norm=1.0,
111
+ )
112
+
113
+ # Quick test override
114
+ if max_steps:
115
+ training_args.max_steps = max_steps
116
+
117
+ # Trainer
118
+ trainer = SFTTrainer(
119
+ model="meta-llama/Llama-3.1-8B-Instruct",
120
+ train_dataset=dataset,
121
+ peft_config=peft_config,
122
+ args=training_args,
123
+ )
124
+
125
+ # Train
126
+ print("Starting training...")
127
+ trainer.train()
128
+
129
+ # Push to Hub
130
+ if push_hub:
131
+ trainer.push_to_hub()
132
+ print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
133
+
134
+ trackio.finish()
135
+
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser()
139
+ parser.add_argument("--max_steps", type=int, default=None, help="Max steps (for quick test)")
140
+ parser.add_argument("--hub_model_id", type=str, default="shaikhsalman/llama-3.1-8b-tulu3-lora")
141
+ parser.add_argument("--no_push", action="store_true", help="Skip hub push")
142
+ args = parser.parse_args()
143
+
144
+ train(
145
+ max_steps=args.max_steps,
146
+ push_hub=not args.no_push,
147
+ hub_model_id=args.hub_model_id,
148
+ )