File size: 3,301 Bytes
7b45378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Generated by Apertus on Public AI
from smol import DistillationTrainer
from transformers import AutoModel, AutoTokenizer
from transformers import DistilBERTForSequenceClassification
from transformers import AdamW
import torch
import torch.nn as nn

# Step 1: Load the large model (teacher model).
# AutoModelForCausalLM (rather than AutoModel) is used so the teacher's
# outputs carry ``.logits``; AutoModel loads only the transformer body and
# returns hidden states, which a logit-based distillation loss cannot use.
from transformers import AutoModelForCausalLM

teacher_model = AutoModelForCausalLM.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
# NOTE(review): this tokenizer comes from a different checkpoint than the
# teacher — confirm it matches the teacher's vocabulary. It is not the
# tokenizer of the DistilBERT student below, which ships its own (uncased)
# vocabulary.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407")

# Step 2: Choose the smaller model (student model).
# Here, we use DistilBERT as an example. The transformers class is spelled
# ``DistilBertForSequenceClassification``; a ``DistilBERT...`` spelling does
# not exist in the library.
from transformers import DistilBertForSequenceClassification

student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
# NOTE(review): the teacher produces next-token logits over its vocabulary,
# while this student produces sequence-classification logits, so the two
# outputs have different shapes and cannot be compared by a KL loss as-is.

# Define the distillation loss function (e.g., using KLDivLoss)
class DistillationLoss(nn.Module):
    """Temperature-scaled soft-target knowledge-distillation loss.

    Computes ``alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))``
    where both inputs are raw (unnormalised) logits over their last
    dimension. Any leading dimensions are flattened, so the KL divergence is
    averaged over every distribution (e.g. every example, or every token
    position).

    Args:
        temperature: Softening temperature ``T``; must be positive. ``T = 1``
            compares the plain softmax distributions.
        alpha: Weight applied to the distillation term.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(self, temperature, alpha):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        self.temperature = temperature
        self.alpha = alpha
        # nn.KLDivLoss has no ``temperature`` argument (passing one raises
        # TypeError); the temperature is applied to the logits in forward.
        # "batchmean" matches the mathematical KL definition, whereas the
        # default "mean" also divides by the number of classes.
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_output, teacher_output):
        """Return the weighted distillation loss for a pair of logit tensors.

        Args:
            student_output: Student logits; gradients flow through these.
            teacher_output: Teacher logits of the same shape; treated as a
                fixed target (detached, so no gradient reaches the teacher).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: If the two logit tensors have different shapes.
        """
        if student_output.shape != teacher_output.shape:
            raise ValueError(
                "student and teacher logits must have the same shape, got "
                f"{tuple(student_output.shape)} and {tuple(teacher_output.shape)}"
            )
        t = self.temperature
        num_classes = student_output.size(-1)
        student_log_probs = (student_output / t).log_softmax(-1).reshape(-1, num_classes)
        teacher_probs = (teacher_output.detach() / t).softmax(-1).reshape(-1, num_classes)
        # Multiply by T**2 so gradient magnitudes stay comparable across
        # temperatures (Hinton et al., 2015).
        return self.kl_loss(student_log_probs, teacher_probs) * (t * t) * self.alpha

# Define a simple training loop
def train_step(model, batch, optimizer, loss_fn, device, teacher=None, task_loss_fn=None):
    """Run one knowledge-distillation optimisation step on the student ``model``.

    The teacher runs without gradients to produce target logits. The student
    runs with gradients. The distillation loss (``loss_fn``) is then
    back-propagated through the student only. When the batch carries labels,
    a supervised task loss is added before back-propagation.

    Args:
        model: Student model to update; called as ``model(**inputs)``.
        batch: Mapping of already-tokenized model inputs (e.g. ``input_ids``
            / ``attention_mask`` tensors) plus an optional ``"labels"`` entry.
            Every non-``"labels"`` entry is forwarded to both models.
        optimizer: Optimizer over the student's parameters.
        loss_fn: Callable ``loss_fn(student_logits, teacher_logits)`` taking
            raw logits, e.g. ``DistillationLoss``.
        device: Device the inputs, labels and logits are moved to.
        teacher: Teacher model; defaults to the module-level ``teacher_model``.
        task_loss_fn: Callable ``task_loss_fn(student_logits, labels)``;
            defaults to cross-entropy.

    Returns:
        ``(total_loss, student_output, teacher_logits)``. ``total_loss`` is a
        Python float, ``student_output`` is the student's raw output, and
        ``teacher_logits`` is the detached teacher logit tensor.
    """
    if teacher is None:
        # Using the student as its own teacher would distil nothing.
        teacher = teacher_model
    if task_loss_fn is None:
        task_loss_fn = nn.functional.cross_entropy

    labels = batch.get("labels")
    # Inputs must already be tokenized: re-tokenizing ``input_ids`` would
    # treat token ids as raw text.
    inputs = {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in batch.items()
        if key != "labels"
    }

    with torch.no_grad():
        teacher_output = teacher(**inputs)
        # Hugging Face models return an output object with ``.logits``;
        # plain modules return the logit tensor directly.
        teacher_logits = getattr(teacher_output, "logits", teacher_output)
        teacher_logits = teacher_logits.detach().to(device)

    student_output = model(**inputs)
    student_logits = getattr(student_output, "logits", student_output).to(device)

    # loss_fn normalises both sides itself, so pass raw teacher logits;
    # passing probabilities here would apply softmax twice.
    total_loss = loss_fn(student_logits, teacher_logits)
    if labels is not None:
        total_loss = total_loss + task_loss_fn(student_logits, labels.to(device))

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item(), student_output, teacher_logits

# Initialize SMOL's DistillationTrainer.
# NOTE(review): the ``smol`` package and the DistillationTrainer keyword
# arguments below are not verified here — check them against the installed
# library's documentation.
from smol.trainer import DistillationTrainer

trainer = DistillationTrainer(
    student_model,
    # ``transformers.AdamW`` is deprecated in favour of PyTorch's AdamW.
    # eps / weight_decay are set to transformers.AdamW's defaults so the
    # optimisation matches the previous configuration.
    optimizer=torch.optim.AdamW(
        student_model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0
    ),  # Example learning rate
    loss_fn=DistillationLoss(temperature=1.0, alpha=0.5),  # Example distillation loss
    # TODO: ``your_train_dataset`` / ``your_eval_dataset`` are placeholders;
    # define them before running or this raises NameError.
    train_dataset=your_train_dataset,  # Your training dataset
    eval_dataset=your_eval_dataset,  # Your evaluation dataset
    device="cuda" if torch.cuda.is_available() else "cpu",  # Use GPU if available
    num_epochs=5,  # Number of epochs
    batch_size=16,  # Batch size
    log_dir="distillation_logs",  # Log directory
)
# NOTE(review): no teacher model is passed to the trainer above; confirm how
# this DistillationTrainer obtains the teacher's targets.

# Train the model
trainer.train()

# Alternatively, you can use SMOL's simplified training loop (as of SMOL 0.3.0, check the latest docs)
# trainer.train(steps=1000, evaluate_every=100, ...)