tobil committed · Commit a91fb36 · verified · 1 parent: e97d105

Upload train_1.7B_v2.py with huggingface_hub

Files changed (1):
  1. train_1.7B_v2.py +102 -0
train_1.7B_v2.py ADDED
@@ -0,0 +1,102 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "trl>=0.12.0",
+ #     "peft>=0.7.0",
+ #     "transformers>=4.45.0",
+ #     "accelerate>=0.24.0",
+ #     "trackio",
+ #     "datasets",
+ #     "bitsandbytes",
+ # ]
+ # ///
+ """
+ Improved Qwen3-1.7B training with best practices for larger models:
+ - Lower learning rate (1e-4 instead of 2e-4)
+ - Higher LoRA rank (32 instead of 16)
+ - More epochs (5 instead of 3)
+ - Weight decay for regularization
+ """
+
+ import trackio
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+
+ # Load dataset from Hub
+ print("Loading dataset...")
+ dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
+ print(f"Loaded {len(dataset)} examples")
+
+ # Create train/eval split
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
+ train_dataset = dataset_split["train"]
+ eval_dataset = dataset_split["test"]
+ print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
+
+ # Training configuration - optimized for larger model
+ config = SFTConfig(
+     output_dir="qmd-query-expansion-1.7B-v2",
+     push_to_hub=True,
+     hub_model_id="tobil/qmd-query-expansion-1.7B-v2",
+     hub_strategy="every_save",
+
+     # Training parameters - lower LR, more epochs for larger model
+     num_train_epochs=5,
+     per_device_train_batch_size=2,
+     gradient_accumulation_steps=8,
+     learning_rate=1e-4,  # Lowered from 2e-4
+     weight_decay=0.01,  # Added regularization
+     max_length=512,
+
+     # Logging & checkpointing
+     logging_steps=25,
+     save_strategy="steps",
+     save_steps=200,
+     save_total_limit=3,
+
+     # Evaluation
+     eval_strategy="steps",
+     eval_steps=200,
+
+     # Optimization
+     warmup_ratio=0.1,
+     lr_scheduler_type="cosine",
+     bf16=True,
+     gradient_checkpointing=True,
+     gradient_checkpointing_kwargs={"use_reentrant": False},
+
+     # Monitoring
+     report_to="trackio",
+     project="qmd-query-expansion",
+     run_name="qwen3-1.7B-lora-v2",
+ )
+
+ # LoRA configuration - higher rank for better learning
+ peft_config = LoraConfig(
+     r=32,  # Increased from 16
+     lora_alpha=64,  # Increased from 32 (2x rank)
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+ )
+
+ # Initialize trainer
+ print("Initializing trainer with Qwen/Qwen3-1.7B...")
+ trainer = SFTTrainer(
+     model="Qwen/Qwen3-1.7B",
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     args=config,
+     peft_config=peft_config,
+ )
+
+ print("Starting training...")
+ trainer.train()
+
+ print("Pushing to Hub...")
+ trainer.push_to_hub()
+
+ trackio.finish()
+ print("Done! Model at: https://huggingface.co/tobil/qmd-query-expansion-1.7B-v2")