|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
DataCollatorForSeq2Seq, |
|
|
) |
|
|
from peft import LoraConfig, get_peft_model |
|
|
import trackio |
|
|
from typing import Dict, Optional |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Avoid fork-related tokenizer warnings/deadlocks when DataLoader workers fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Startup banner describing the distillation run.
print("="*50)
print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
print("Dataset: TeleQnA (Telecommunications)")
print("="*50)

# Model identifiers and distillation hyperparameters.
TEACHER_MODEL = "Qwen/Qwen3-4B"    # frozen teacher (larger model)
STUDENT_MODEL = "Qwen/Qwen3-0.6B"  # student to be LoRA-tuned
TEMPERATURE = 2.0                  # softmax temperature applied to both models' logits
ALPHA = 0.5                        # mixing weight: alpha * reversed_KL + (1 - alpha) * CE
|
|
|
|
|
|
|
|
# The student's tokenizer is used throughout; the KD loss below compares
# teacher and student logits elementwise, which assumes both models share
# the same vocabulary — TODO confirm for this Qwen3 pair.
print(f"\nLoading tokenizer from {STUDENT_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # tokenizer ships without a dedicated pad token
tokenizer.padding_side = "right"

# NOTE(review): only the 'test' split is loaded; it is re-split into
# train/eval further down with train_test_split.
print("\nLoading TeleQnA dataset...")
raw_dataset = load_dataset('netop/TeleQnA', split='test')
|
|
|
|
|
def format_for_distillation(example):
    """Turn one TeleQnA record into a Qwen chat-format prompt/completion pair.

    The prompt holds the question plus numbered answer options wrapped in
    `<|im_start|>user ... <|im_end|>` markers; the completion holds the answer
    plus its explanation terminated by `<|im_end|>`.
    """
    # Number the answer options starting at 1 (empty when 'choices' is
    # absent or falsy in the record).
    numbered_options = [
        f'{idx}. {option}'
        for idx, option in enumerate(example.get('choices') or [], 1)
    ]

    question = example['question'] + "\n\nOptions:\n" + "\n".join(numbered_options)

    # The dataset historically misspells the field as 'explaination'; fall
    # back to the correctly spelled key if the misspelled one is empty/absent.
    explanation = example.get('explaination', '') or example.get('explanation', '')
    answer = f"{example['answer']}\n\nExplanation: {explanation}"

    return {
        "prompt": f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
        "completion": f"{answer}<|im_end|>",
    }
|
|
|
|
|
# Map every raw record to {"prompt", "completion"}; the original TeleQnA
# columns are dropped so only the chat-formatted fields remain.
print("Preprocessing dataset...")
dataset = raw_dataset.map(format_for_distillation, remove_columns=raw_dataset.column_names)
|
|
|
|
|
|
|
|
def tokenize_function(examples):
    """Tokenize prompt/completion pairs into causal-LM training features.

    Prompt and completion are tokenized separately so that the prompt span
    can be masked out of the loss (-100) while completion token ids serve
    as the labels. Returns input_ids, attention_mask and labels lists.
    """
    common = dict(truncation=True, max_length=512, padding=False)
    enc_prompt = tokenizer(examples["prompt"], **common)
    enc_completion = tokenizer(examples["completion"], **common)

    input_ids = []
    attention_mask = []
    labels = []
    for p_ids, c_ids, p_mask, c_mask in zip(
        enc_prompt["input_ids"],
        enc_completion["input_ids"],
        enc_prompt["attention_mask"],
        enc_completion["attention_mask"],
    ):
        input_ids.append(p_ids + c_ids)
        attention_mask.append(p_mask + c_mask)
        # Prompt positions are ignored by the loss; completion ids are targets.
        labels.append([-100] * len(p_ids) + c_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
|
|
|
|
|
# Batched tokenization; text columns are dropped once features are built.
print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["prompt", "completion"],
)

# TeleQnA was loaded as a single split, so carve out a held-out 10% for
# evaluation (fixed seed for reproducibility).
print("Creating train/eval split...")
dataset_split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")
|
|
|
|
|
|
|
|
# Load the teacher in bf16 and freeze it: it only supplies logits for the
# distillation loss, so no gradients are ever needed on its parameters.
print(f"\nLoading teacher model: {TEACHER_MODEL}...")
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()  # disable dropout etc. for stable teacher distributions
for param in teacher_model.parameters():
    param.requires_grad = False
print("β Teacher model loaded and frozen")
|
|
|
|
|
|
|
|
# Load the student and wrap it with LoRA adapters so only a small set of
# low-rank matrices is trained during distillation.
print(f"\nLoading student model: {STUDENT_MODEL}...")
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

lora_config = LoraConfig(
    r=16,             # adapter rank
    lora_alpha=32,    # scaling factor (effective scale alpha / r = 2.0)
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections only
    bias="none",
    task_type="CAUSAL_LM"
)
student_model = get_peft_model(student_model, lora_config)
student_model.print_trainable_parameters()

# Sanity check: PEFT wrapping must leave the adapter weights trainable.
trainable_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
assert trainable_params > 0, "No trainable parameters found!"
print(f"β Student model loaded with LoRA ({trainable_params:,} trainable params)")
|
|
|
|
|
|
|
|
class MiniLLMTrainer(Trainer):
    """
    MiniLLM-style distillation trainer with:
    1. Reversed KL Divergence KL(student || teacher) — mode-seeking, so the
       student concentrates on the teacher's high-probability modes.
    2. Teacher token sampling — during training, completion labels are
       replaced by tokens sampled from the temperature-scaled teacher
       distribution, so the CE term distills teacher behavior.

    Total loss: alpha * reversed_KL + (1 - alpha) * cross_entropy.
    """

    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model   # frozen reference model (eval mode, no grads)
        self.temperature = temperature       # softmax temperature for both models
        self.alpha = alpha                   # KL vs CE mixing weight
        self.use_teacher_sampling = True     # sample CE targets from the teacher while training

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        MiniLLM loss:
        1. Sample tokens from the teacher distribution (training only).
        2. Compute reversed KLD between student and teacher on completion positions.
        3. Combine with the cross-entropy loss.

        BUGFIX vs. the original: for a causal LM, logits at position t model
        the token at position t+1. The original compared UNSHIFTED logits
        against labels (training the student to copy its input token) and
        placed teacher samples one position early. Both the CE term and the
        KL mask now use the standard shift: logits[:, :-1] vs labels[:, 1:].
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        labels = inputs.pop("labels")

        # Teacher forward pass; the teacher is frozen so no graph is kept.
        with torch.no_grad():
            teacher_outputs = self.teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_outputs.logits

        if self.use_teacher_sampling and model.training:
            # Sample one token per position from the tempered teacher distribution.
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
            sampled_tokens = torch.multinomial(
                teacher_probs.view(-1, teacher_probs.size(-1)),
                num_samples=1,
            ).view(teacher_probs.size(0), teacher_probs.size(1))

            # The sample drawn from logits at position t is the teacher's
            # prediction for token t+1, so it becomes the target in label
            # slot t+1. Position 0 has no predecessor and stays masked.
            teacher_targets = torch.full_like(labels, -100)
            teacher_targets[:, 1:] = sampled_tokens[:, :-1]

            # Replace only supervised (completion) positions; prompt/padding
            # positions keep their -100 mask.
            mask = labels != -100
            labels = torch.where(mask, teacher_targets, labels)

        student_outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_outputs.logits

        # Causal shift: align predictions at t with targets at t+1.
        shift_student_logits = student_logits[:, :-1, :]
        shift_teacher_logits = teacher_logits[:, :-1, :]
        shift_labels = labels[:, 1:]

        # Hard-target cross entropy over completion positions only.
        ce_loss = F.cross_entropy(
            shift_student_logits.reshape(-1, shift_student_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction='mean'
        )

        # Tempered distributions for the KL term.
        student_log_probs = F.log_softmax(shift_student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(shift_teacher_logits / self.temperature, dim=-1)
        student_probs = F.softmax(shift_student_logits / self.temperature, dim=-1)

        # Reversed KL per position: sum_v p_s(v) * (log p_s(v) - log p_t(v)).
        reversed_kl = torch.sum(
            student_probs * (student_log_probs - teacher_log_probs),
            dim=-1
        )

        # Average over supervised positions; epsilon guards an all-masked batch.
        loss_mask = (shift_labels != -100).float()
        reversed_kl_masked = (reversed_kl * loss_mask).sum() / (loss_mask.sum() + 1e-8)

        # Standard distillation temperature correction (gradients scale as 1/T^2).
        reversed_kl_masked = reversed_kl_masked * (self.temperature ** 2)

        total_loss = self.alpha * reversed_kl_masked + (1 - self.alpha) * ce_loss

        # Log loss components at the configured logging cadence.
        if self.state.global_step % self.args.logging_steps == 0:
            self.log({
                "loss/total": total_loss.item(),
                "loss/reversed_kl": reversed_kl_masked.item(),
                "loss/cross_entropy": ce_loss.item(),
            })

        return (total_loss, student_outputs) if return_outputs else total_loss
|
|
|
|
|
|
|
|
# Trainer configuration for the distillation run.
training_args = TrainingArguments(
    output_dir="qwen3-0.6b-telecom-distilled",

    # Schedule: effective batch size = 2 * 8 = 16 per device.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,

    # Optimizer: LR in the typical range for LoRA fine-tuning.
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # Periodic evaluation and checkpointing (keep last 3 checkpoints).
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # Metric logging via trackio.
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-telecom-minillm",

    # Precision/memory settings.
    gradient_checkpointing=False,
    bf16=True,

    # Push every checkpoint save to the public Hub repo.
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-telecom-distilled",
    hub_strategy="every_save",
    hub_private_repo=False,

    # remove_unused_columns=False keeps 'labels' in the batch so the custom
    # compute_loss can pop it; default column pruning would drop it.
    dataloader_num_workers=0,
    remove_unused_columns=False,
)
|
|
|
|
|
|
|
|
# DataCollatorForSeq2Seq pads input_ids/attention_mask with the pad token
# and pads the 'labels' column with -100 so padded positions are ignored
# by the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=student_model,
    padding=True,
)

# Wire everything into the custom trainer; the extra keyword arguments are
# consumed by MiniLLMTrainer.__init__ before delegating to Trainer.
print("\nInitializing MiniLLM Trainer...")
trainer = MiniLLMTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
    alpha=ALPHA,
)
|
|
|
|
|
|
|
|
# Run the distillation, then push the resulting model to the Hub.
print("\n" + "="*50)
print("Starting MiniLLM Knowledge Distillation...")
print(f"β Teacher: {TEACHER_MODEL} (frozen)")
print(f"β Student: {STUDENT_MODEL} (LoRA)")
print(f"β Method: Reversed KLD + Teacher Sampling")
print(f"β Temperature: {TEMPERATURE}")
print(f"β Alpha: {ALPHA}")
print(f"β Dataset: TeleQnA ({len(train_dataset)} train, {len(eval_dataset)} eval)")
print("="*50 + "\n")

trainer.train()

print("\nPushing distilled model to Hub...")
trainer.push_to_hub(commit_message="MiniLLM distillation: Qwen3-4B -> Qwen3-0.6B on TeleQnA")

print("\n" + "="*50)
print("β Knowledge Distillation Complete!")
print("="*50)
|
|
|