# /// script
# dependencies = ["transformers>=4.40.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
# ///
"""Knowledge distillation of Qwen3-4B into Qwen3-0.6B on TeleQnA.

Implements the MiniLLM recipe: the student is optimized with a reversed
KL divergence KL(student || teacher) combined with a cross-entropy term,
and during training the cross-entropy targets are tokens sampled from the
teacher's temperature-softened distribution instead of the gold completion.
"""
import os

import numpy as np  # kept: may be used by downstream metric callbacks
import torch
import torch.nn.functional as F
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
from typing import Dict, Optional  # noqa: F401
import trackio  # noqa: F401 -- backend for report_to="trackio"

# Silence the fork-after-parallelism warning from the Rust tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("=" * 50)
print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
print("Dataset: TeleQnA (Telecommunications)")
print("=" * 50)

# Configuration
TEACHER_MODEL = "Qwen/Qwen3-4B"
STUDENT_MODEL = "Qwen/Qwen3-0.6B"
TEMPERATURE = 2.0  # temperature for softening both distributions
ALPHA = 0.5        # weight of the reversed-KL term vs cross-entropy

# Load tokenizer (student and teacher share the Qwen vocabulary, so one
# tokenizer serves both models).
print(f"\nLoading tokenizer from {STUDENT_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load TeleQnA dataset (only a "test" split is published).
print("\nLoading TeleQnA dataset...")
raw_dataset = load_dataset('netop/TeleQnA', split='test')


def format_for_distillation(example):
    """Convert a TeleQnA example into a ChatML prompt/completion pair.

    Produces:
        prompt:     "<|im_start|>user\\n{question+options}<|im_end|>\\n<|im_start|>assistant\\n"
        completion: "{answer+explanation}<|im_end|>"
    """
    choices_text = []
    if 'choices' in example and example['choices']:
        for i, choice in enumerate(example['choices'], 1):
            choices_text.append(f"{i}. {choice}")

    question = f"""{example['question']}

Options:
{chr(10).join(choices_text)}"""

    # The dataset historically misspells the field as "explaination";
    # accept either spelling.
    explanation = example.get('explaination', '') or example.get('explanation', '')
    answer = f"""{example['answer']}

Explanation: {explanation}"""

    prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    completion = f"{answer}<|im_end|>"
    return {"prompt": prompt, "completion": completion}


print("Preprocessing dataset...")
dataset = raw_dataset.map(format_for_distillation, remove_columns=raw_dataset.column_names)


def tokenize_function(examples):
    """Tokenize prompt+completion pairs for causal-LM training.

    The prompt portion of `labels` is set to -100 so only completion
    tokens contribute to the loss.
    """
    prompt_encodings = tokenizer(
        examples["prompt"],
        truncation=True,
        max_length=512,
        padding=False,
    )
    completion_encodings = tokenizer(
        examples["completion"],
        truncation=True,
        max_length=512,
        padding=False,
    )

    input_ids = [
        p + c
        for p, c in zip(prompt_encodings["input_ids"], completion_encodings["input_ids"])
    ]
    attention_mask = [
        p + c
        for p, c in zip(prompt_encodings["attention_mask"], completion_encodings["attention_mask"])
    ]
    # -100 masks the prompt out of the loss; completion keeps real token ids.
    labels = [
        [-100] * len(p) + c
        for p, c in zip(prompt_encodings["input_ids"], completion_encodings["input_ids"])
    ]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["prompt", "completion"],
)

# Create train/eval split
print("Creating train/eval split...")
dataset_split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")

# Load Teacher Model (frozen -- inference only)
print(f"\nLoading teacher model: {TEACHER_MODEL}...")
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()
for param in teacher_model.parameters():
    param.requires_grad = False
print("✓ Teacher model loaded and frozen")

# Load Student Model (trainable with LoRA)
print(f"\nLoading student model: {STUDENT_MODEL}...")
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Apply LoRA adapters to the attention projections only.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
student_model = get_peft_model(student_model, lora_config)
student_model.print_trainable_parameters()

# Verify trainable parameters
trainable_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
assert trainable_params > 0, "No trainable parameters found!"
print(f"✓ Student model loaded with LoRA ({trainable_params:,} trainable params)")


class MiniLLMTrainer(Trainer):
    """Trainer implementing the MiniLLM distillation objective.

    1. Reversed KL divergence KL(student || teacher) -- mode-seeking,
       which pushes the small student toward the teacher's major modes.
    2. Teacher token sampling: during training, cross-entropy targets are
       drawn from the teacher's softened distribution.
    """

    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        self.use_teacher_sampling = True  # MiniLLM uses teacher sampling

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute alpha * reversed_KL + (1 - alpha) * CE.

        NOTE: because the loss is computed manually (labels are popped
        before the forward pass), the standard causal-LM shift must be
        applied here: logits at position t predict the token at t+1.
        The original version omitted this shift, misaligning both the
        cross-entropy targets and the teacher-sampled tokens by one.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        labels = inputs.pop("labels")

        # Teacher forward pass (no gradient -- teacher is frozen).
        with torch.no_grad():
            teacher_outputs = self.teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_outputs.logits

        # Student forward pass.
        student_outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_outputs.logits

        # Causal-LM alignment: drop the last logit, drop the first label.
        shift_student = student_logits[:, :-1, :]
        shift_teacher = teacher_logits[:, :-1, :]
        shift_labels = labels[:, 1:]

        # Teacher token sampling (key part of MiniLLM): replace the gold
        # targets with tokens drawn from the teacher's softened distribution,
        # except at masked (-100) positions.
        if self.use_teacher_sampling and model.training:
            teacher_probs = F.softmax(shift_teacher / self.temperature, dim=-1)
            # float() cast: multinomial does not accept bf16 on all backends.
            sampled_tokens = torch.multinomial(
                teacher_probs.float().reshape(-1, teacher_probs.size(-1)),
                num_samples=1,
            ).view(shift_labels.size())
            target_mask = shift_labels != -100
            shift_labels = torch.where(target_mask, sampled_tokens, shift_labels)

        # 1. Cross-entropy against (teacher-sampled) targets.
        ce_loss = F.cross_entropy(
            shift_student.reshape(-1, shift_student.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction='mean',
        )

        # 2. Reversed KL divergence KL(student || teacher) =
        #    sum(P_student * (log P_student - log P_teacher))
        student_log_probs = F.log_softmax(shift_student / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(shift_teacher / self.temperature, dim=-1)
        student_probs = F.softmax(shift_student / self.temperature, dim=-1)
        reversed_kl = torch.sum(
            student_probs * (student_log_probs - teacher_log_probs),
            dim=-1,
        )

        # Average only over completion positions (labels != -100).
        loss_mask = (shift_labels != -100).float()
        reversed_kl_masked = (reversed_kl * loss_mask).sum() / (loss_mask.sum() + 1e-8)
        # Standard T^2 scaling keeps gradient magnitude comparable across
        # temperatures (Hinton et al.).
        reversed_kl_masked = reversed_kl_masked * (self.temperature ** 2)

        total_loss = self.alpha * reversed_kl_masked + (1 - self.alpha) * ce_loss

        # Periodic component-wise logging.
        if self.state.global_step % self.args.logging_steps == 0:
            self.log({
                "loss/total": total_loss.item(),
                "loss/reversed_kl": reversed_kl_masked.item(),
                "loss/cross_entropy": ce_loss.item(),
            })

        return (total_loss, student_outputs) if return_outputs else total_loss


# Training arguments
training_args = TrainingArguments(
    output_dir="qwen3-0.6b-telecom-distilled",
    # Training
    num_train_epochs=3,
    per_device_train_batch_size=2,  # increased from 1 (no gradient checkpointing)
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size = 16
    # Optimization
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    # Evaluation
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    # Logging
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-telecom-minillm",
    # Memory
    gradient_checkpointing=False,  # disabled: conflicts with LoRA + dual-model distillation
    bf16=True,
    # Hub
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-telecom-distilled",
    hub_strategy="every_save",
    hub_private_repo=False,
    # Performance
    dataloader_num_workers=0,  # avoid multiprocessing issues with tokenizers
    remove_unused_columns=False,
)

# Data collator: pads input_ids/attention_mask and pads labels with -100.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=student_model,
    padding=True,
)

# Initialize trainer
print("\nInitializing MiniLLM Trainer...")
trainer = MiniLLMTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
    alpha=ALPHA,
)

# Start training
print("\n" + "=" * 50)
print("Starting MiniLLM Knowledge Distillation...")
print(f"✓ Teacher: {TEACHER_MODEL} (frozen)")
print(f"✓ Student: {STUDENT_MODEL} (LoRA)")
print(f"✓ Method: Reversed KLD + Teacher Sampling")
print(f"✓ Temperature: {TEMPERATURE}")
print(f"✓ Alpha: {ALPHA}")
print(f"✓ Dataset: TeleQnA ({len(train_dataset)} train, {len(eval_dataset)} eval)")
print("=" * 50 + "\n")

trainer.train()

# Push final model
print("\nPushing distilled model to Hub...")
trainer.push_to_hub(commit_message="MiniLLM distillation: Qwen3-4B -> Qwen3-0.6B on TeleQnA")

print("\n" + "=" * 50)
print("✓ Knowledge Distillation Complete!")
print("=" * 50)