tobil commited on
Commit
9f9c531
·
verified ·
1 Parent(s): b186bb2

Upload train_1.7B_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_1.7B_sft.py +101 -0
train_1.7B_sft.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.36.0",
7
+ # "accelerate>=0.24.0",
8
+ # "trackio",
9
+ # ]
10
+ # ///
11
+
12
+ """
13
+ SFT training for Qwen3-1.7B query expansion model.
14
+ Dataset: tobil/qmd-query-expansion-train-v2
15
+ Output: tobil/qmd-query-expansion-1.7B-sft
16
+ """
17
+
18
+ import trackio
19
+ from datasets import load_dataset
20
+ from peft import LoraConfig
21
+ from trl import SFTTrainer, SFTConfig
22
+
23
+ # Load dataset
24
+ print("Loading dataset...")
25
+ dataset = load_dataset("tobil/qmd-query-expansion-train-v2", split="train")
26
+ print(f"Dataset loaded: {len(dataset)} examples")
27
+
28
+ # Create train/eval split
29
+ print("Creating train/eval split...")
30
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
31
+ train_dataset = dataset_split["train"]
32
+ eval_dataset = dataset_split["test"]
33
+ print(f" Train: {len(train_dataset)} examples")
34
+ print(f" Eval: {len(eval_dataset)} examples")
35
+
36
+ # Training configuration
37
+ config = SFTConfig(
38
+ # Hub settings - use separate repo, not subfolder
39
+ output_dir="qmd-query-expansion-1.7B-sft",
40
+ push_to_hub=True,
41
+ hub_model_id="tobil/qmd-query-expansion-1.7B-sft",
42
+ hub_strategy="every_save",
43
+
44
+ # Training parameters
45
+ num_train_epochs=3,
46
+ per_device_train_batch_size=4,
47
+ gradient_accumulation_steps=4,
48
+ learning_rate=2e-4,
49
+ max_length=512,
50
+
51
+ # Logging & checkpointing
52
+ logging_steps=10,
53
+ save_strategy="steps",
54
+ save_steps=100,
55
+ save_total_limit=2,
56
+
57
+ # Evaluation
58
+ eval_strategy="steps",
59
+ eval_steps=100,
60
+
61
+ # Optimization
62
+ warmup_ratio=0.1,
63
+ lr_scheduler_type="cosine",
64
+
65
+ # Monitoring
66
+ report_to="trackio",
67
+ project="qmd-query-expansion",
68
+ run_name="1.7B-sft-v2",
69
+ )
70
+
71
+ # LoRA configuration
72
+ peft_config = LoraConfig(
73
+ r=16,
74
+ lora_alpha=32,
75
+ lora_dropout=0.05,
76
+ bias="none",
77
+ task_type="CAUSAL_LM",
78
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
79
+ )
80
+
81
+ # Initialize and train
82
+ print("Initializing trainer...")
83
+ trainer = SFTTrainer(
84
+ model="Qwen/Qwen3-1.7B",
85
+ train_dataset=train_dataset,
86
+ eval_dataset=eval_dataset,
87
+ args=config,
88
+ peft_config=peft_config,
89
+ )
90
+
91
+ print("Starting training...")
92
+ trainer.train()
93
+
94
+ print("Pushing to Hub...")
95
+ trainer.push_to_hub()
96
+
97
+ # Finish Trackio tracking
98
+ trackio.finish()
99
+
100
+ print("Complete! Model at: https://huggingface.co/tobil/qmd-query-expansion-1.7B-sft")
101
+ print("View metrics at: https://huggingface.co/spaces/tobil/trackio")