# distill-test / distill.py
# Author: Oleg Lavrovsky
# Initial testing (commit 7b45378, unverified)
# Generated by Apertus on Public AI
from smol import DistillationTrainer
from transformers import AutoModel, AutoTokenizer
# FIX: the class is spelled DistilBertForSequenceClassification;
# `DistilBERTForSequenceClassification` does not exist in transformers.
from transformers import DistilBertForSequenceClassification
# FIX: `transformers.AdamW` is deprecated/removed; use the torch optimizer.
from torch.optim import AdamW
import torch
import torch.nn as nn

# Step 1: Load the large model (teacher model) and a tokenizer.
# NOTE(review): the tokenizer is loaded from a different checkpoint
# (Mistral-Nemo) than the teacher (Apertus-8B) — confirm the two models
# actually share a vocabulary, otherwise the teacher inputs are garbage.
teacher_model = AutoModel.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407")

# Step 2: Choose the smaller model (student model) — DistilBERT here.
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
# Define the distillation loss function (e.g., using KLDivLoss)
class DistillationLoss(nn.Module):
    """Temperature-scaled KL-divergence distillation loss (Hinton et al., 2015).

    Args:
        temperature: Softmax temperature. Logits are divided by this value
            before the softmax, and the loss is rescaled by temperature**2
            so gradient magnitudes stay comparable across temperatures.
        alpha: Weight applied to the distillation term.
    """

    def __init__(self, temperature, alpha):
        super().__init__()
        # FIX: nn.KLDivLoss has no `temperature` argument (the original
        # raised TypeError). The temperature is applied to the logits in
        # forward() instead. "batchmean" matches the KL definition.
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_output, teacher_output):
        """Return alpha * T**2 * KL(teacher || student).

        Both arguments are raw (unnormalized) logits over the last dim.
        """
        t = self.temperature
        student_log_probs = (student_output / t).log_softmax(-1)
        teacher_probs = (teacher_output / t).softmax(-1)
        return self.kl_loss(student_log_probs, teacher_probs) * (t * t) * self.alpha
# Define a simple training loop
def train_step(model, batch, optimizer, loss_fn, device,
               teacher=None, task_loss_fn=None):
    """Run one knowledge-distillation training step.

    Args:
        model: The student model being trained.
        batch: Dict with "input_ids" (raw text for the tokenizer — TODO
            confirm; the original tokenized this field) and "labels".
        optimizer: Optimizer over the student's parameters.
        loss_fn: Distillation loss taking (student_logits, teacher_logits).
        device: Device to run on.
        teacher: Frozen teacher model. Defaults to the module-level
            `teacher_model` for backward compatibility. (FIX: the original
            ran `model` for BOTH passes, so teacher == student.)
        task_loss_fn: Supervised loss; defaults to cross-entropy. (FIX: the
            original called an undefined name `loss_function`.)

    Returns:
        (total_loss_value, student_output, teacher_output)
    """
    # Tokenize the input. FIX: the original referenced an undefined
    # `tokenizer_args`; use explicit, standard tokenizer options.
    inputs = tokenizer(batch["input_ids"], return_tensors="pt",
                       padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    labels = batch["labels"].to(device)

    # Teacher forward pass — no gradients needed.
    teacher_net = teacher if teacher is not None else teacher_model
    with torch.no_grad():
        teacher_output = teacher_net(**inputs)
    teacher_logits = (teacher_output.logits
                      if hasattr(teacher_output, "logits") else teacher_output)
    teacher_logits = teacher_logits.detach()

    # Student forward pass.
    student_output = model(**inputs)
    student_logits = (student_output.logits
                      if hasattr(student_output, "logits") else student_output)

    # Distillation loss on RAW logits. FIX: the original pre-softmaxed the
    # teacher at the call site while the loss softmaxed again (double softmax).
    distillation_loss = loss_fn(student_logits, teacher_logits)

    # Supervised task loss (cross-entropy for classification by default).
    ce = task_loss_fn if task_loss_fn is not None else nn.functional.cross_entropy
    task_loss = ce(student_logits, labels)
    total_loss = distillation_loss + task_loss

    # Backward and optimize (student only; teacher logits are detached).
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item(), student_output, teacher_output
# Initialize SMOL's DistillationTrainer.
# NOTE(review): `smol.trainer` is not a known published package, and this
# import shadows the `DistillationTrainer` already imported at the top of the
# file — confirm which module is actually intended.
from smol.trainer import DistillationTrainer
# NOTE(review): `your_train_dataset` / `your_eval_dataset` are undefined
# placeholders; this call raises NameError until real datasets are supplied.
trainer = DistillationTrainer(
student_model,
optimizer=AdamW(student_model.parameters(), lr=1e-5), # Example learning rate
loss_fn=DistillationLoss(temperature=1.0, alpha=0.5), # Example distillation loss
train_dataset=your_train_dataset, # Your training dataset (placeholder)
eval_dataset=your_eval_dataset, # Your evaluation dataset (placeholder)
device="cuda" if torch.cuda.is_available() else "cpu", # Use GPU if available
num_epochs=5, # Number of epochs
batch_size=16, # Batch size
log_dir="distillation_logs", # Log directory
)
# Train the model
trainer.train()
# Alternatively, you can use SMOL's simplified training loop (as of SMOL 0.3.0, check the latest docs)
# trainer.train(steps=1000, evaluate_every=100, ...)