wlabchoi committed on
Commit
ca616f3
·
verified ·
1 Parent(s): 145a6a7

Upload train_qwen3_wirelessmath.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_qwen3_wirelessmath.py +140 -0
train_qwen3_wirelessmath.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "bitsandbytes"]
# ///

# Standard library
import os

# Third-party
import torch
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Keep HF tokenizers from emitting the fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
# Banner: announce the run configuration up front.
SEP = "=" * 50
print(SEP)
print("Fine-tuning Qwen3-0.6B on WirelessMATHBench-XL")
print("Method: SFT with LoRA")
print("Dataset: Wireless Communications Math")
print(SEP)

# Pull the train/test splits of the wireless-math benchmark from the Hub.
print("\nLoading WirelessMATHBench-XL dataset...")
train_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='train')
eval_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='test')

print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")
29
+ def format_for_sft(example):
30
+ """
31
+ Convert WirelessMATHBench-XL format to chat messages
32
+ Dataset has: prompt (pre-formatted), correct_answer, and other fields
33
+ """
34
+ # Use the pre-formatted prompt
35
+ prompt = example['prompt']
36
+ answer = example['correct_answer']
37
+
38
+ # Create chat format
39
+ messages = [
40
+ {'role': 'user', 'content': prompt},
41
+ {'role': 'assistant', 'content': answer}
42
+ ]
43
+
44
+ return {'messages': messages}
45
+
46
# Replace every raw column with the single `messages` column the trainer consumes.
print("Preprocessing dataset...")
train_dataset = train_dataset.map(format_for_sft, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_for_sft, remove_columns=eval_dataset.column_names)
55
+
56
# LoRA adapter: train low-rank updates on the attention projections only,
# leaving the base model weights frozen.
print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,                  # adapter rank
    lora_alpha=32,         # scaling factor (alpha / r = 2)
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
66
+
67
# SFT training arguments.
print("Configuring training arguments...")
training_args = SFTConfig(
    output_dir="qwen3-wireless-math",

    # Schedule: 3 epochs, effective batch size 4 * 4 = 16.
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,

    # Optimizer / schedule.
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # Periodic evaluation and checkpointing.
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # Metrics reporting via trackio.
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-wireless-math",
    project="wireless-math-finetuning",

    # Precision / memory. Gradient checkpointing stays off: per the original
    # author it caused gradient-computation issues with this setup.
    gradient_checkpointing=False,
    bf16=True,

    # Push checkpoints to the Hub as they are saved.
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-wireless-math",
    hub_strategy="every_save",
    hub_private_repo=False,

    # Single-process data loading; keep the `messages` column for the trainer.
    dataloader_num_workers=0,
    remove_unused_columns=False,
)
111
+
112
# Build the trainer. Passing a model id string lets TRL load the base
# checkpoint itself; peft_config wraps it with the LoRA adapter.
print("\nInitializing SFT Trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
121
+
122
# Announce the run, then kick off training.
print("\n" + "=" * 50)
print("Starting Fine-Tuning...")
# Fix: these three were f-strings with no placeholders (ruff F541);
# plain string literals produce identical output.
print("Model: Qwen3-0.6B")
print("Dataset: WirelessMATHBench-XL")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
print("Epochs: 3")  # NOTE: keep in sync with SFTConfig.num_train_epochs above
print("=" * 50 + "\n")

trainer.train()

# Final push of the fine-tuned adapter/model to the Hub.
print("\nPushing final model to Hub...")
trainer.push_to_hub(commit_message="Fine-tuning complete - Qwen3-0.6B on WirelessMATHBench-XL")

print("\n" + "=" * 50)
print("Fine-Tuning Completed Successfully!")
print("=" * 50)