shaikhsalman committed on
Commit
4019af4
·
verified ·
1 Parent(s): bae9038

feat: add CLI runner + training recipe docs

Browse files
Files changed (1) hide show
  1. ai-ml/hf-finetuning/run_finetune.py +67 -0
ai-ml/hf-finetuning/run_finetune.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # SFT Fine-Tuning — CLI Entry Point (LoRA Without Regret config)
3
+ # =============================================================================
4
+ # Usage:
5
+ # # Default: tulu-3-sft + Llama-3.1-8B
6
+ # python run_finetune.py
7
+ #
8
+ # # OpenThoughts reasoning dataset
9
+ # python run_finetune.py --dataset_key openthoughts-114k
10
+ #
11
+ # # Ultrachat fallback
12
+ # python run_finetune.py --dataset_key ultrachat-200k
13
+ #
14
+ # # Custom hub model ID
15
+ # python run_finetune.py --hub_model_id my-org/my-model-v2
16
+ # =============================================================================
17
+
18
+ import argparse
19
+ import sys
20
+ from finetune import FinetuneConfig, finetune, DATASET_REGISTRY
21
+
22
+
23
def main(argv=None):
    """Parse CLI options, overlay them on a default ``FinetuneConfig``, and run ``finetune``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour, so existing ``main()`` callers are
            unaffected.
    """
    parser = argparse.ArgumentParser(description="SFT Fine-Tuning (LoRA Without Regret)")
    parser.add_argument("--dataset_key", default="tulu-3-sft",
                        choices=list(DATASET_REGISTRY.keys()),
                        help="Dataset to train on")
    parser.add_argument("--hub_model_id", default=None,
                        help="HuggingFace Hub model ID for push")
    parser.add_argument("--num_train_epochs", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--lora_r", type=int, default=None)
    parser.add_argument("--per_device_train_batch_size", type=int, default=None)
    parser.add_argument("--max_seq_length", type=int, default=None)

    args = parser.parse_args(argv)

    config = FinetuneConfig()

    # String options: an empty value carries no information, so a plain
    # truthiness check is kept for these two.
    if args.dataset_key:
        config.dataset_key = args.dataset_key
    if args.hub_model_id:
        config.hub_model_id = args.hub_model_id

    # Numeric options: compare against None rather than relying on truthiness,
    # otherwise an explicit falsy value (e.g. ``--learning_rate 0``) would be
    # silently discarded and the config default used instead.
    numeric_overrides = (
        "num_train_epochs",
        "learning_rate",
        "lora_r",
        "per_device_train_batch_size",
        "max_seq_length",
    )
    for option in numeric_overrides:
        value = getattr(args, option)
        if value is not None:
            setattr(config, option, value)

    # Echo the resolved configuration before handing it to finetune().
    print(f"Config: model={config.model_name}")
    print(f" dataset={config.dataset_key}")
    print(f" lora_r={config.lora_r}, lora_alpha={config.lora_alpha}")
    print(f" target_modules={config.target_modules}")
    print(f" lr={config.learning_rate}, epochs={config.num_train_epochs}")
    print(f" effective_batch={config.per_device_train_batch_size * config.gradient_accumulation_steps}")
    print(f" packing={config.packing}, strategy={config.packing_strategy}")
    print(f" assistant_only_loss={config.assistant_only_loss}")

    finetune(config)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()