|
|
|
|
|
from smol import DistillationTrainer |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
from transformers import DistilBERTForSequenceClassification |
|
|
from transformers import AdamW |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
# --- Teacher / student / tokenizer setup -------------------------------------
# NOTE(review): AutoModel returns the bare encoder (hidden states, no .logits),
# but train_step below reads a .logits attribute — confirm the intended head
# (AutoModelForCausalLM / ...ForSequenceClassification) for the teacher.
teacher_model = AutoModel.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
teacher_model.eval()  # teacher is frozen: inference only during distillation
for p in teacher_model.parameters():
    p.requires_grad_(False)

# Load the tokenizer from the TEACHER checkpoint. The original pulled it from
# an unrelated Mistral checkpoint that matched neither model.
# NOTE(review): the student (DistilBERT) has its own vocabulary; KL between
# teacher and student logits requires a shared label/vocab space — verify the
# teacher/student pairing before training.
tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")

# Correct class casing: the top-of-file import of
# "DistilBERTForSequenceClassification" raises ImportError; the real
# transformers name is DistilBertForSequenceClassification.
from transformers import DistilBertForSequenceClassification

student_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased"
)
|
|
|
|
|
|
|
|
class DistillationLoss(nn.Module):
    """Temperature-scaled KL-divergence loss for knowledge distillation.

    Computes ``alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))``.
    The ``T^2`` factor keeps gradient magnitudes comparable across
    temperatures (standard distillation practice).

    Args:
        temperature: softening temperature ``T`` applied to both logit sets.
        alpha: weight of the distillation term in the total loss.
    """

    def __init__(self, temperature, alpha):
        super().__init__()
        # Bug fix: nn.KLDivLoss has no ``temperature`` kwarg (the original
        # raised TypeError). Store T and apply it in forward() instead.
        self.temperature = temperature
        self.alpha = alpha
        # "batchmean" is the mathematically correct KL reduction
        # (the default "mean" averages over classes as well).
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_output, teacher_output):
        """Return the weighted distillation loss.

        Args:
            student_output: raw student logits, shape (..., num_classes).
            teacher_output: raw teacher logits, same shape.
        """
        t = self.temperature
        # KLDivLoss expects log-probabilities as input, probabilities as target.
        log_p_student = (student_output / t).log_softmax(-1)
        p_teacher = (teacher_output / t).softmax(-1)
        return self.kl_loss(log_p_student, p_teacher) * (t * t) * self.alpha
|
|
|
|
|
|
|
|
def train_step(model, batch, optimizer, loss_fn, device, teacher=None):
    """Run one knowledge-distillation training step on the student.

    Args:
        model: student model being trained (must already be on ``device``).
        batch: dict with keys ``"input_ids"`` and ``"labels"``.
            NOTE(review): ``batch["input_ids"]`` is fed to the module-global
            ``tokenizer``, so it is presumably raw text despite the key name —
            confirm against the data loader.
        optimizer: optimizer over the student's parameters.
        loss_fn: distillation loss taking ``(student_logits, teacher_logits)``
            as RAW logits (it applies its own softmax internally).
        device: torch device for tensors.
        teacher: optional frozen teacher model. Defaults to ``model``
            (self-distillation), matching the original behavior where the
            same network produced both outputs.

    Returns:
        Tuple ``(total_loss_value, student_output, teacher_logits)``.
    """
    if teacher is None:
        teacher = model

    # Tokenize here; the original referenced an undefined ``tokenizer_args``
    # global (NameError). Move tensors to the target device — the original
    # only moved the outputs, so a CUDA run would have failed earlier.
    inputs = tokenizer(
        batch["input_ids"], return_tensors="pt", padding=True, truncation=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    labels = batch["labels"].to(device)

    # Teacher forward pass: no graph, detached soft targets.
    with torch.no_grad():
        teacher_output = teacher(**inputs)
        teacher_logits = (
            teacher_output.logits
            if hasattr(teacher_output, "logits")
            else teacher_output
        )
        teacher_logits = teacher_logits.detach()

    # Student forward pass (gradients flow only through the student).
    student_output = model(**inputs)
    student_logits = (
        student_output.logits
        if hasattr(student_output, "logits")
        else student_output
    )

    # Pass RAW teacher logits: the original applied softmax here and the loss
    # applied softmax again, double-softmaxing the target distribution.
    distillation_loss = loss_fn(student_logits, teacher_logits)

    # Hard-label task loss; the original called an undefined
    # ``loss_function`` (NameError).
    task_loss = nn.functional.cross_entropy(student_logits, labels)
    total_loss = distillation_loss + task_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item(), student_output, teacher_logits
|
|
|
|
|
|
|
|
# --- Trainer wiring ----------------------------------------------------------
from smol.trainer import DistillationTrainer

# TODO(review): ``your_train_dataset`` / ``your_eval_dataset`` are undefined
# placeholders — replace them with real Dataset objects before running.
trainer = DistillationTrainer(
    student_model,
    # torch.optim.AdamW replaces transformers.AdamW, which is deprecated and
    # removed in recent transformers releases (same algorithm, stock PyTorch).
    optimizer=torch.optim.AdamW(student_model.parameters(), lr=1e-5),
    loss_fn=DistillationLoss(temperature=1.0, alpha=0.5),
    train_dataset=your_train_dataset,
    eval_dataset=your_eval_dataset,
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_epochs=5,
    batch_size=16,
    log_dir="distillation_logs",
)

trainer.train()
|
|
|
|
|
|
|
|
|
|
|
|