wlabchoi committed on
Commit
6eb1b0c
·
verified ·
1 Parent(s): ca616f3

Upload train_qwen3_wirelessmath.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_qwen3_wirelessmath.py +122 -23
train_qwen3_wirelessmath.py CHANGED
@@ -7,16 +7,18 @@ import torch
7
  from datasets import load_dataset
8
  from peft import LoraConfig
9
  from trl import SFTTrainer, SFTConfig
 
10
  import trackio
11
 
12
  # Disable tokenizer parallelism warning
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
 
15
- print("="*50)
16
  print("Fine-tuning Qwen3-0.6B on WirelessMATHBench-XL")
17
- print("Method: SFT with LoRA")
18
  print("Dataset: Wireless Communications Math")
19
- print("="*50)
 
20
 
21
  # Load WirelessMATHBench-XL dataset
22
  print("\nLoading WirelessMATHBench-XL dataset...")
@@ -26,24 +28,112 @@ eval_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='test')
26
  print(f"Train examples: {len(train_dataset)}")
27
  print(f"Eval examples: {len(eval_dataset)}")
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
def format_for_sft(example):
    """Map a WirelessMATHBench-XL record onto the chat-message schema.

    The dataset ships a ready-to-use 'prompt' plus a 'correct_answer';
    this packs them into the user/assistant turn list that SFTTrainer
    expects under the 'messages' key.
    """
    return {
        'messages': [
            {'role': 'user', 'content': example['prompt']},
            {'role': 'assistant', 'content': example['correct_answer']},
        ]
    }
45
 
46
- print("Preprocessing dataset...")
 
 
 
47
  train_dataset = train_dataset.map(
48
  format_for_sft,
49
  remove_columns=train_dataset.column_names
@@ -53,6 +143,8 @@ eval_dataset = eval_dataset.map(
53
  remove_columns=eval_dataset.column_names
54
  )
55
 
 
 
56
  # Configure LoRA for efficient fine-tuning
57
  print("\nConfiguring LoRA...")
58
  peft_config = LoraConfig(
@@ -91,7 +183,7 @@ training_args = SFTConfig(
91
  # Logging and monitoring
92
  logging_steps=10,
93
  report_to="trackio",
94
- run_name="qwen3-0.6b-wireless-math",
95
  project="wireless-math-finetuning",
96
 
97
  # Memory optimization
@@ -100,7 +192,7 @@ training_args = SFTConfig(
100
 
101
  # Hub integration
102
  push_to_hub=True,
103
- hub_model_id="wlabchoi/qwen3-0.6b-wireless-math",
104
  hub_strategy="every_save",
105
  hub_private_repo=False,
106
 
@@ -120,21 +212,28 @@ trainer = SFTTrainer(
120
  )
121
 
122
  # Start training
123
- print("\n" + "="*50)
124
- print("Starting Fine-Tuning...")
 
125
  print(f"Model: Qwen3-0.6B")
126
- print(f"Dataset: WirelessMATHBench-XL")
127
  print(f"Train: {len(train_dataset)} examples")
128
  print(f"Eval: {len(eval_dataset)} examples")
129
  print(f"Epochs: 3")
130
- print("="*50 + "\n")
 
131
 
132
  trainer.train()
133
 
134
  # Push final model to Hub
135
  print("\nPushing final model to Hub...")
136
- trainer.push_to_hub(commit_message="Fine-tuning complete - Qwen3-0.6B on WirelessMATHBench-XL")
137
-
138
- print("\n" + "="*50)
139
- print("Fine-Tuning Completed Successfully!")
140
- print("="*50)
 
 
 
 
 
 
7
  from datasets import load_dataset
8
  from peft import LoraConfig
9
  from trl import SFTTrainer, SFTConfig
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
  import trackio
12
 
13
  # Disable tokenizer parallelism warning
14
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
 
16
+ print("="*60)
17
  print("Fine-tuning Qwen3-0.6B on WirelessMATHBench-XL")
18
+ print("Method: SFT with LoRA + Reasoning Generation")
19
  print("Dataset: Wireless Communications Math")
20
+ print("Fix: Preserves <think></think> capability")
21
+ print("="*60)
22
 
23
  # Load WirelessMATHBench-XL dataset
24
  print("\nLoading WirelessMATHBench-XL dataset...")
 
28
  print(f"Train examples: {len(train_dataset)}")
29
  print(f"Eval examples: {len(eval_dataset)}")
30
 
31
+ # Load Teacher Model for Reasoning Generation (Preprocessing Step)
32
+ TEACHER_MODEL = "Qwen/Qwen2.5-3B-Instruct"
33
+ print(f"\n{'='*60}")
34
+ print(f"STEP 1: Generating Reasoning Steps (Preserves <think></think>)")
35
+ print(f"Teacher Model: {TEACHER_MODEL}")
36
+ print(f"{'='*60}")
37
+
38
+ teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, trust_remote_code=True)
39
+ teacher_model = AutoModelForCausalLM.from_pretrained(
40
+ TEACHER_MODEL,
41
+ torch_dtype=torch.bfloat16,
42
+ device_map="auto",
43
+ trust_remote_code=True,
44
+ )
45
+ teacher_model.eval()
46
+ print("✓ Teacher model loaded for reasoning generation\n")
47
+
48
def generate_reasoning_batch(examples):
    """Augment a batch of examples with teacher-generated reasoning.

    For each 'prompt' in the batch, asks the teacher model to solve the
    problem step-by-step inside <think></think> tags, then returns the
    cleaned generations under a new 'reasoning_answer' column (the other
    columns pass through unchanged).

    Fixes over the previous version:
      * pads on the LEFT so batched decoder-only generation is not
        conditioned on trailing pad tokens,
      * decodes with skip_special_tokens=True so <|im_end|> / pad tokens
        never leak into the SFT targets (<think>/</think> are plain text
        for the Qwen2.5 tokenizer, so they survive the cleanup — TODO
        confirm if the teacher model changes),
      * re-attaches the opening <think> tag, which lives in the prompt
        and is therefore absent from the raw generation.
    """
    prompts = examples['prompt']
    answers = examples['correct_answer']

    # Build teacher prompts; the assistant turn is pre-opened with
    # <think> so the model starts reasoning immediately.
    reasoning_prompts = [
        f"""<|im_start|>user
{prompt}

Solve step-by-step. Put reasoning in <think></think> tags, then give final answer.<|im_end|>
<|im_start|>assistant
<think>"""
        for prompt in prompts
    ]

    # Left padding: with right padding, shorter prompts would generate
    # after a run of pad tokens, degrading output quality.
    teacher_tokenizer.padding_side = "left"
    inputs = teacher_tokenizer(
        reasoning_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(teacher_model.device)

    # Some chat tokenizers ship without a dedicated pad token; fall back
    # to EOS so generate() does not warn or fail.
    pad_id = teacher_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = teacher_tokenizer.eos_token_id

    with torch.no_grad():
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=False,  # greedy: deterministic teacher targets
            pad_token_id=pad_id,
        )

    # All rows share one padded prompt length, so a single slice offset
    # isolates the newly generated tokens.
    prompt_len = inputs['input_ids'].shape[1]
    responses_with_reasoning = []
    for i, output in enumerate(outputs):
        generated_ids = output[prompt_len:]
        # skip_special_tokens drops <|im_end|>/pad tokens so they cannot
        # pollute the training targets.
        response = teacher_tokenizer.decode(
            generated_ids, skip_special_tokens=True
        ).strip()

        # Restore the opening <think> tag (it was part of the prompt) and
        # ensure the format: <think>reasoning</think>\n\nanswer
        if '</think>' not in response:
            response = f"<think>\n{response}\n</think>\n\n{answers[i]}"
        elif answers[i] not in response:
            response = f"<think>\n{response}\n\n{answers[i]}"
        else:
            response = f"<think>\n{response}"

        responses_with_reasoning.append(response)

    return {"reasoning_answer": responses_with_reasoning}
96
+
97
+ print("Generating reasoning for training set (this may take time)...")
98
+ train_dataset = train_dataset.map(
99
+ generate_reasoning_batch,
100
+ batched=True,
101
+ batch_size=4,
102
+ desc="Generating reasoning"
103
+ )
104
+
105
+ print("Generating reasoning for eval set...")
106
+ eval_dataset = eval_dataset.map(
107
+ generate_reasoning_batch,
108
+ batched=True,
109
+ batch_size=4,
110
+ desc="Generating reasoning"
111
+ )
112
+
113
+ print("✓ Reasoning generation complete!\n")
114
+
115
+ # Clean up teacher model to free memory
116
+ del teacher_model
117
+ del teacher_tokenizer
118
+ torch.cuda.empty_cache()
119
+ print("✓ Teacher model unloaded\n")
120
+
121
def format_for_sft(example):
    """Turn a reasoning-augmented record into SFT chat messages.

    Pairs the original 'prompt' with the teacher-generated
    'reasoning_answer' (which carries the <think></think> trace) as a
    single user/assistant exchange under the 'messages' key.
    """
    user_turn = {'role': 'user', 'content': example['prompt']}
    assistant_turn = {'role': 'assistant', 'content': example['reasoning_answer']}
    return {'messages': [user_turn, assistant_turn]}
132
 
133
+ print(f"{'='*60}")
134
+ print(f"STEP 2: Formatting for SFT Training")
135
+ print(f"{'='*60}\n")
136
+
137
  train_dataset = train_dataset.map(
138
  format_for_sft,
139
  remove_columns=train_dataset.column_names
 
143
  remove_columns=eval_dataset.column_names
144
  )
145
 
146
+ print("✓ Dataset formatted with reasoning preserved")
147
+
148
  # Configure LoRA for efficient fine-tuning
149
  print("\nConfiguring LoRA...")
150
  peft_config = LoraConfig(
 
183
  # Logging and monitoring
184
  logging_steps=10,
185
  report_to="trackio",
186
+ run_name="qwen3-0.6b-wireless-math-reasoning",
187
  project="wireless-math-finetuning",
188
 
189
  # Memory optimization
 
192
 
193
  # Hub integration
194
  push_to_hub=True,
195
+ hub_model_id="wlabchoi/qwen3-0.6b-wireless-math-reasoning",
196
  hub_strategy="every_save",
197
  hub_private_repo=False,
198
 
 
212
  )
213
 
214
  # Start training
215
+ print("\n" + "="*60)
216
+ print("STEP 3: SFT Training on Reasoning-Augmented Data")
217
+ print("="*60)
218
  print(f"Model: Qwen3-0.6B")
219
+ print(f"Dataset: WirelessMATHBench-XL (with generated reasoning)")
220
  print(f"Train: {len(train_dataset)} examples")
221
  print(f"Eval: {len(eval_dataset)} examples")
222
  print(f"Epochs: 3")
223
+ print(f"Result: Model preserves <think></think> capability")
224
+ print("="*60 + "\n")
225
 
226
  trainer.train()
227
 
228
  # Push final model to Hub
229
  print("\nPushing final model to Hub...")
230
+ trainer.push_to_hub(commit_message="SFT complete - Qwen3-0.6B on WirelessMATH with reasoning preservation")
231
+
232
+ print("\n" + "="*60)
233
+ print("Fine-Tuning Complete - Reasoning Preserved!")
234
+ print("="*60)
235
+ print("Model now:")
236
+ print(" ✓ Knows wireless communications mathematics")
237
+ print(" ✓ Maintains <think></think> chain-of-thought")
238
+ print(" ✓ Shows reasoning steps before answers")
239
+ print("="*60)