tobil commited on
Commit
ebad68f
·
verified ·
1 Parent(s): eda5b1d

Upload train_0.6B.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_0.6B.py +92 -0
train_0.6B.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.45.0",
7
+ # "accelerate>=0.24.0",
8
+ # "trackio",
9
+ # "datasets",
10
+ # "bitsandbytes",
11
+ # ]
12
+ # ///
13
+
14
+ import trackio
15
+ from datasets import load_dataset
16
+ from peft import LoraConfig
17
+ from trl import SFTTrainer, SFTConfig
18
+
19
+ # Load dataset from Hub
20
+ print("Loading dataset...")
21
+ dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
22
+ print(f"Loaded {len(dataset)} examples")
23
+
24
+ # Create train/eval split
25
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
26
+ train_dataset = dataset_split["train"]
27
+ eval_dataset = dataset_split["test"]
28
+ print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
29
+
30
+ # Training configuration
31
+ config = SFTConfig(
32
+ output_dir="qmd-query-expansion-0.6B",
33
+ push_to_hub=True,
34
+ hub_model_id="tobil/qmd-query-expansion-0.6B",
35
+ hub_strategy="every_save",
36
+
37
+ # Training parameters
38
+ num_train_epochs=3,
39
+ per_device_train_batch_size=4,
40
+ gradient_accumulation_steps=4,
41
+ learning_rate=2e-4,
42
+ max_length=512,
43
+
44
+ # Logging & checkpointing
45
+ logging_steps=25,
46
+ save_strategy="steps",
47
+ save_steps=200,
48
+ save_total_limit=2,
49
+
50
+ # Evaluation
51
+ eval_strategy="steps",
52
+ eval_steps=200,
53
+
54
+ # Optimization
55
+ warmup_ratio=0.1,
56
+ lr_scheduler_type="cosine",
57
+ bf16=True,
58
+
59
+ # Monitoring
60
+ report_to="trackio",
61
+ project="qmd-query-expansion",
62
+ run_name="qwen3-0.6B-lora",
63
+ )
64
+
65
+ # LoRA configuration
66
+ peft_config = LoraConfig(
67
+ r=16,
68
+ lora_alpha=32,
69
+ lora_dropout=0.05,
70
+ bias="none",
71
+ task_type="CAUSAL_LM",
72
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
73
+ )
74
+
75
+ # Initialize trainer
76
+ print("Initializing trainer with Qwen/Qwen3-0.6B...")
77
+ trainer = SFTTrainer(
78
+ model="Qwen/Qwen3-0.6B",
79
+ train_dataset=train_dataset,
80
+ eval_dataset=eval_dataset,
81
+ args=config,
82
+ peft_config=peft_config,
83
+ )
84
+
85
+ print("Starting training...")
86
+ trainer.train()
87
+
88
+ print("Pushing to Hub...")
89
+ trainer.push_to_hub()
90
+
91
+ trackio.finish()
92
+ print("Done! Model at: https://huggingface.co/tobil/qmd-query-expansion-0.6B")