wlabchoi committed on
Commit e5ff53a · verified · 1 Parent(s): f9c8c49

Upload train_qwen3_distillation.py with huggingface_hub

Files changed (1): train_qwen3_distillation.py (+336 lines)

train_qwen3_distillation.py (new file):

# /// script
# dependencies = ["transformers>=4.46.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
# ///

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model
import trackio  # used indirectly via report_to="trackio" below

# transformers >= 4.46 is assumed: this script uses `eval_strategy` and a
# `compute_loss` that accepts the `num_items_in_batch` argument the Trainer
# passes from that release onward.

print("=" * 50)
print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
print("Method: MiniLLM-style (Reverse KLD + Teacher Sampling)")
print("Dataset: TeleQnA (Telecommunications)")
print("=" * 50)

# Configuration
TEACHER_MODEL = "Qwen/Qwen3-4B"
STUDENT_MODEL = "Qwen/Qwen3-0.6B"
TEMPERATURE = 2.0  # Temperature for softening both distributions
ALPHA = 0.5  # Weight of the reverse-KL (distillation) term in the total loss

# Load tokenizer. The Qwen3 family shares one tokenizer, so the student's
# tokenizer also serves the teacher; a shared vocabulary is what makes
# logit-level distillation possible in the first place.
print(f"\nLoading tokenizer from {STUDENT_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the TeleQnA dataset (loaded from its 'test' split here; a train/eval
# split is created further down).
print("\nLoading TeleQnA dataset...")
raw_dataset = load_dataset("netop/TeleQnA", split="test")

def format_for_distillation(example):
    """Convert a TeleQnA record into a prompt/completion pair in Qwen chat format."""
    choices_text = []
    if example.get("choices"):
        for i, choice in enumerate(example["choices"], 1):
            choices_text.append(f"{i}. {choice}")

    question = f"""{example['question']}

Options:
{chr(10).join(choices_text)}"""

    # The dataset spells the field 'explaination'; check that first, then the
    # correct spelling as a fallback.
    explanation = example.get("explaination", "") or example.get("explanation", "")
    answer = f"""{example['answer']}

Explanation: {explanation}"""

    # Hand-rolled Qwen chat template: user turn as the prompt, assistant turn
    # as the completion the student is trained on.
    prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    completion = f"{answer}<|im_end|>"

    return {"prompt": prompt, "completion": completion}
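
# For illustration only, a formatted record looks roughly like this (the field
# contents are hypothetical, not taken from the dataset):
#   prompt:     '<|im_start|>user\nWhat does RLC stand for?\n\nOptions:\n1. ...<|im_end|>\n<|im_start|>assistant\n'
#   completion: 'option 2: Radio Link Control\n\nExplanation: ...<|im_end|>'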

print("Preprocessing dataset...")
dataset = raw_dataset.map(format_for_distillation, remove_columns=raw_dataset.column_names)

# Tokenize with prompt/completion structure
def tokenize_function(examples):
    # Tokenize prompts (input)
    prompt_encodings = tokenizer(
        examples["prompt"],
        truncation=True,
        max_length=512,
        padding=False,
    )

    # Tokenize completions (target)
    completion_encodings = tokenizer(
        examples["completion"],
        truncation=True,
        max_length=512,
        padding=False,
    )

    # Concatenate prompt and completion token ids. The combined sequence can
    # reach 1024 tokens, well within Qwen3's context window.
    input_ids = [
        p + c for p, c in zip(prompt_encodings["input_ids"], completion_encodings["input_ids"])
    ]
    attention_mask = [
        p + c for p, c in zip(prompt_encodings["attention_mask"], completion_encodings["attention_mask"])
    ]

    # Labels: -100 over the prompt (excluded from the loss), real token ids
    # over the completion.
    labels = [
        [-100] * len(p) + c for p, c in zip(prompt_encodings["input_ids"], completion_encodings["input_ids"])
    ]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
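
# The resulting layout per example (illustrative):
#   input_ids: [ prompt tokens ........ ][ completion tokens ..... ]
#   labels:    [ -100, -100, ..., -100  ][ completion tokens ..... ]
# so both the cross-entropy and the distillation losses below are restricted
# to assistant-completion positions.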

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["prompt", "completion"],
)

# Create train/eval split
print("Creating train/eval split...")
dataset_split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")

# Load Teacher Model (frozen)
print(f"\nLoading teacher model: {TEACHER_MODEL}...")
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()
for param in teacher_model.parameters():
    param.requires_grad = False
print("✓ Teacher model loaded and frozen")

# Load Student Model (trainable with LoRA)
print(f"\nLoading student model: {STUDENT_MODEL}...")
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
# With gradient checkpointing + LoRA, the inputs must require grads so that
# gradients can flow back through the otherwise-frozen embedding layer.
student_model.enable_input_require_grads()

# Apply LoRA to the attention projections
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
student_model = get_peft_model(student_model, lora_config)
student_model.print_trainable_parameters()
print("✓ Student model loaded with LoRA")
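
# What the trainer below optimizes is a MiniLLM-style approximation, not the
# paper's full policy-gradient objective:
#   L = ALPHA * T^2 * KL(p_student^T || p_teacher^T) + (1 - ALPHA) * CE
# where p^T are temperature-T softened next-token distributions and the CE is
# taken against teacher-sampled target tokens.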

# MiniLLM Distillation Trainer
class MiniLLMTrainer(Trainer):
    """
    MiniLLM-style distillation with:
    1. Reverse KL divergence: KL(student || teacher)
    2. Teacher token sampling for the training targets
    """
    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        self.use_teacher_sampling = True  # MiniLLM trains on teacher-sampled tokens

    # `num_items_in_batch` is passed by transformers >= 4.46 and accepted here
    # for compatibility; this loss does not need it.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        1. Sample target tokens from the teacher distribution
        2. Compute the reverse KLD between student and teacher
        3. Blend with the cross-entropy loss
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        labels = inputs.pop("labels")

        # Get teacher predictions (no gradient)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_outputs.logits

        # Teacher token sampling (key part of MiniLLM)
        if self.use_teacher_sampling and self.training:
            # Sample one token per position from the teacher's softened
            # distribution (cast to float32, which multinomial expects)
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
            sampled_tokens = torch.multinomial(
                teacher_probs.view(-1, teacher_probs.size(-1)).float(),
                num_samples=1,
            ).view(teacher_probs.size(0), teacher_probs.size(1))

            # Logits at position t predict the token at t+1, so shift the
            # samples right by one before substituting them for the completion
            # labels (prompt positions stay at -100).
            shifted_samples = torch.full_like(labels, -100)
            shifted_samples[:, 1:] = sampled_tokens[:, :-1]
            mask = labels != -100
            labels = torch.where(mask, shifted_samples, labels)

        # Student forward pass
        student_outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_outputs.logits

        # 1. Cross-entropy against the (teacher-sampled) labels, with the
        # standard causal-LM shift: logits at t are scored on the token at t+1
        shift_logits = student_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        ce_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="mean",
        )

        # 2. Reverse KL divergence: KL(student || teacher). Unlike the forward
        # KL, this is mode-seeking: the student concentrates its mass on the
        # teacher's dominant modes instead of spreading over all of them.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)
        student_probs = F.softmax(student_logits / self.temperature, dim=-1)

        # Reverse KLD = sum_v P_student(v) * (log P_student(v) - log P_teacher(v))
        reversed_kl = torch.sum(
            student_probs * (student_log_probs - teacher_log_probs),
            dim=-1,
        )  # [batch, seq_len]

        # Average only over positions that predict a completion token
        # (position t predicts the token at t+1)
        loss_mask = torch.zeros_like(reversed_kl)
        loss_mask[:, :-1] = (labels[:, 1:] != -100).float()
        reversed_kl_masked = (reversed_kl * loss_mask).sum() / (loss_mask.sum() + 1e-8)

        # Scale by temperature squared, the usual distillation convention, so
        # gradient magnitudes stay comparable across temperature settings
        reversed_kl_masked = reversed_kl_masked * (self.temperature ** 2)

        # Combined loss: alpha * reverse KL + (1 - alpha) * CE
        total_loss = self.alpha * reversed_kl_masked + (1 - self.alpha) * ce_loss

        # Log the loss components at the configured interval
        if self.state.global_step % self.args.logging_steps == 0:
            self.log({
                "loss/total": total_loss.item(),
                "loss/reversed_kl": reversed_kl_masked.item(),
                "loss/cross_entropy": ce_loss.item(),
            })

        return (total_loss, student_outputs) if return_outputs else total_loss

# Training arguments
training_args = TrainingArguments(
    output_dir="qwen3-0.6b-telecom-distilled",

    # Training
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,

    # Optimization
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # Evaluation
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # Logging
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-telecom-minillm",

    # Memory
    gradient_checkpointing=True,
    bf16=True,

    # Hub
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-telecom-distilled",
    hub_strategy="every_save",
    hub_private_repo=False,

    # Performance
    dataloader_num_workers=4,
    remove_unused_columns=False,  # keep labels for the custom compute_loss
)
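
# Note: per_device_train_batch_size=1 with gradient_accumulation_steps=16
# gives an effective batch of 16 sequences per optimizer step (per device).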

# Data collator: DataCollatorForSeq2Seq pads input_ids/attention_mask with the
# pad token and pads labels with -100, matching the masking convention used in
# compute_loss above.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=student_model,
    padding=True,
)

# Initialize trainer
print("\nInitializing MiniLLM Trainer...")
trainer = MiniLLMTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
    alpha=ALPHA,
)

# Start training
print("\n" + "=" * 50)
print("Starting MiniLLM Knowledge Distillation...")
print(f"✓ Teacher: {TEACHER_MODEL} (frozen)")
print(f"✓ Student: {STUDENT_MODEL} (LoRA)")
print("✓ Method: Reverse KLD + Teacher Sampling")
print(f"✓ Temperature: {TEMPERATURE}")
print(f"✓ Alpha: {ALPHA}")
print(f"✓ Dataset: TeleQnA ({len(train_dataset)} train, {len(eval_dataset)} eval)")
print("=" * 50 + "\n")

trainer.train()

# Push final model
print("\nPushing distilled model to Hub...")
trainer.push_to_hub(commit_message="MiniLLM distillation: Qwen3-4B -> Qwen3-0.6B on TeleQnA")

print("\n" + "=" * 50)
print("✓ Knowledge Distillation Complete!")
print("=" * 50)