# training-scripts / train_qwen3_distillation.py
# Uploaded by wlabchoi via huggingface_hub (revision 145a6a7, verified)
# /// script
# dependencies = ["transformers>=4.40.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
# ///
import os
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model
import trackio
from typing import Dict, Optional
import numpy as np
# Disable tokenizer parallelism warning (forked dataloader workers + fast
# tokenizers otherwise spam a deadlock warning)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run banner.
print("="*50)
print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
print("Dataset: TeleQnA (Telecommunications)")
print("="*50)

# Configuration
TEACHER_MODEL = "Qwen/Qwen3-4B"
STUDENT_MODEL = "Qwen/Qwen3-0.6B"
TEMPERATURE = 2.0  # Temperature for softening distributions in both loss terms
ALPHA = 0.5  # Weight for distillation (reversed-KL) loss vs. cross-entropy

# Load tokenizer.
# NOTE(review): only the student tokenizer is loaded, and its token ids are fed
# to BOTH models in compute_loss — this assumes teacher and student share the
# same vocabulary (plausible within the Qwen3 family; confirm if models change,
# since the KL term compares logits index-by-index).
print(f"\nLoading tokenizer from {STUDENT_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
tokenizer.padding_side = "right"  # right-pad, standard for causal-LM training

# Load TeleQnA dataset
# NOTE(review): only the 'test' split is requested — presumably the dataset
# ships a single split; verify on the Hub.
print("\nLoading TeleQnA dataset...")
raw_dataset = load_dataset('netop/TeleQnA', split='test')
def format_for_distillation(example):
    """Map one TeleQnA record to a ``{"prompt", "completion"}`` pair.

    The prompt embeds the question plus a numbered option list inside
    Qwen chat-template markers; the completion carries the answer and its
    explanation.  TeleQnA spells the explanation key ``explaination``, so
    both spellings are consulted.
    """
    # Numbered option lines ("1. ...", "2. ...") — empty when no choices.
    options = [
        f'{idx}. {option}'
        for idx, option in enumerate(example.get('choices') or [], 1)
    ]
    question = f"{example['question']}\nOptions:\n" + "\n".join(options)

    # Dataset uses the misspelled key; fall back to the correct spelling.
    explanation = example.get('explaination', '') or example.get('explanation', '')
    answer = f"{example['answer']}\nExplanation: {explanation}"

    # Qwen chat markers; the assistant turn is left open for the completion.
    return {
        "prompt": f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
        "completion": f"{answer}<|im_end|>",
    }
# Convert every raw TeleQnA row to {"prompt", "completion"} chat strings and
# drop the original columns so only those two fields flow downstream.
print("Preprocessing dataset...")
dataset = raw_dataset.map(format_for_distillation, remove_columns=raw_dataset.column_names)
# Tokenize with prompt/completion structure
def tokenize_function(examples):
    """Tokenize a batch of prompt/completion pairs for causal-LM training.

    Prompt and completion are encoded separately (each truncated to 512
    tokens) and concatenated.  Labels copy the completion tokens and mask
    the prompt span with -100 so loss is computed on the completion only.
    Sequences are left unpadded; the data collator pads per batch.
    """
    prompt_enc = tokenizer(
        examples["prompt"],
        truncation=True,
        max_length=512,
        padding=False,
    )
    completion_enc = tokenizer(
        examples["completion"],
        truncation=True,
        max_length=512,
        padding=False,
    )

    input_ids, attention_mask, labels = [], [], []
    for p_ids, c_ids, p_mask, c_mask in zip(
        prompt_enc["input_ids"],
        completion_enc["input_ids"],
        prompt_enc["attention_mask"],
        completion_enc["attention_mask"],
    ):
        input_ids.append(p_ids + c_ids)
        attention_mask.append(p_mask + c_mask)
        # -100 == ignore_index: no loss on prompt tokens.
        labels.append([-100] * len(p_ids) + c_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
print("Tokenizing dataset...")
# Batched map yields variable-length (unpadded) sequences; dynamic padding
# happens later in the data collator.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["prompt", "completion"],
)

# Create train/eval split (90/10, fixed seed for reproducibility)
print("Creating train/eval split...")
dataset_split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")
# --- Teacher model (frozen): provides the target next-token distribution ---
print(f"\nLoading teacher model: {TEACHER_MODEL}...")
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()  # disable dropout etc. for deterministic targets
for param in teacher_model.parameters():
    # Never updated; also keeps these params out of any optimizer state.
    param.requires_grad = False
print("✓ Teacher model loaded and frozen")

# --- Student model: trainable through LoRA adapters only ---
print(f"\nLoading student model: {STUDENT_MODEL}...")
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Apply LoRA to the attention projections.
lora_config = LoraConfig(
    r=16,  # adapter rank
    lora_alpha=32,  # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
student_model = get_peft_model(student_model, lora_config)
student_model.print_trainable_parameters()

# Sanity check: LoRA wrapping must leave something trainable.  Use an explicit
# raise rather than `assert`, which is stripped when Python runs with -O.
trainable_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
if trainable_params == 0:
    raise RuntimeError("No trainable parameters found!")
print(f"✓ Student model loaded with LoRA ({trainable_params:,} trainable params)")
# MiniLLM Distillation Trainer
class MiniLLMTrainer(Trainer):
"""
MiniLLM approach with:
1. Reversed KL Divergence: KL(student || teacher)
2. Teacher token sampling for training targets
"""
def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
super().__init__(*args, **kwargs)
self.teacher_model = teacher_model
self.temperature = temperature
self.alpha = alpha
self.use_teacher_sampling = True # MiniLLM uses teacher sampling
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
MiniLLM Loss:
1. Sample tokens from teacher distribution
2. Compute Reversed KLD between student and teacher
3. Combine with cross-entropy loss
"""
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
labels = inputs.pop("labels")
# Get teacher predictions (no gradient)
with torch.no_grad():
teacher_outputs = self.teacher_model(
input_ids=input_ids,
attention_mask=attention_mask,
)
teacher_logits = teacher_outputs.logits
# Teacher token sampling (key part of MiniLLM)
if self.use_teacher_sampling and model.training:
# Sample from teacher's softmax distribution
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
# Sample tokens: [batch, seq_len]
sampled_tokens = torch.multinomial(
teacher_probs.view(-1, teacher_probs.size(-1)),
num_samples=1
).view(teacher_probs.size(0), teacher_probs.size(1))
# Replace labels with teacher-sampled tokens (except where labels are -100)
mask = labels != -100
labels = torch.where(mask, sampled_tokens, labels)
# Student forward pass
student_outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
)
student_logits = student_outputs.logits
# 1. Cross-Entropy Loss (with teacher-sampled tokens)
ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
labels.view(-1),
ignore_index=-100,
reduction='mean'
)
# 2. Reversed KL Divergence: KL(student || teacher)
# This encourages student to cover all modes of teacher distribution
student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)
student_probs = F.softmax(student_logits / self.temperature, dim=-1)
# Reversed KLD = sum(P_student * log(P_student / P_teacher))
reversed_kl = torch.sum(
student_probs * (student_log_probs - teacher_log_probs),
dim=-1
)
# Mask padding and non-target tokens
loss_mask = (labels != -100).float()
if loss_mask.dim() == 2:
# If labels are 2D, add dimension for broadcasting
loss_mask = loss_mask.unsqueeze(-1)
reversed_kl_masked = (reversed_kl * loss_mask.squeeze(-1)).sum() / (loss_mask.sum() + 1e-8)
# Scale by temperature squared
reversed_kl_masked = reversed_kl_masked * (self.temperature ** 2)
# Combined loss: alpha * Reversed_KL + (1-alpha) * CE
total_loss = self.alpha * reversed_kl_masked + (1 - self.alpha) * ce_loss
# Logging
if self.state.global_step % self.args.logging_steps == 0:
self.log({
"loss/total": total_loss.item(),
"loss/reversed_kl": reversed_kl_masked.item(),
"loss/cross_entropy": ce_loss.item(),
})
return (total_loss, student_outputs) if return_outputs else total_loss
# Training arguments
training_args = TrainingArguments(
    output_dir="qwen3-0.6b-telecom-distilled",
    # Training
    num_train_epochs=3,
    per_device_train_batch_size=2,  # Increased from 1 (no gradient checkpointing)
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # Effective batch size = 2 * 8 = 16 per device
    # Optimization
    learning_rate=2e-4,  # typical LoRA learning rate (higher than full fine-tune)
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    # Evaluation / checkpointing cadence (step-based, not epoch-based)
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,  # keep only the 3 most recent checkpoints on disk
    # Logging
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-telecom-minillm",
    # Memory
    gradient_checkpointing=False,  # Disabled - conflicts with LoRA + dual model distillation
    bf16=True,  # matches the bf16 weights both models were loaded in
    # Hub
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-telecom-distilled",
    hub_strategy="every_save",
    hub_private_repo=False,
    # Performance
    dataloader_num_workers=0,  # Avoid multiprocessing issues with tokenizers
    # Keep all dataset columns: the custom compute_loss reads them directly.
    remove_unused_columns=False,
)
# Data collator: dynamically pads input_ids/attention_mask per batch and pads
# labels with -100 so padded positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=student_model,
    padding=True,
)

# Initialize trainer with the distillation-specific extras (frozen teacher,
# temperature, loss mixing weight).
print("\nInitializing MiniLLM Trainer...")
trainer = MiniLLMTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
    alpha=ALPHA,
)
# Start training.  (Checkmarks were mojibake "βœ“" in the original —
# restored to the intended "✓".)
print("\n" + "="*50)
print("Starting MiniLLM Knowledge Distillation...")
print(f"✓ Teacher: {TEACHER_MODEL} (frozen)")
print(f"✓ Student: {STUDENT_MODEL} (LoRA)")
print("✓ Method: Reversed KLD + Teacher Sampling")
print(f"✓ Temperature: {TEMPERATURE}")
print(f"✓ Alpha: {ALPHA}")
print(f"✓ Dataset: TeleQnA ({len(train_dataset)} train, {len(eval_dataset)} eval)")
print("="*50 + "\n")
trainer.train()

# Push the final adapter/model state to the Hub (in addition to the
# per-save pushes configured via hub_strategy="every_save").
print("\nPushing distilled model to Hub...")
trainer.push_to_hub(commit_message="MiniLLM distillation: Qwen3-4B -> Qwen3-0.6B on TeleQnA")
print("\n" + "="*50)
print("✓ Knowledge Distillation Complete!")
print("="*50)