import torch
import wandb
import yaml
from transformers import Trainer, TrainingArguments
from data.datasets import load_and_tokenize_data
from models.full_finetune_model import get_full_finetune_model
from models.student_model import get_student_model

# Load the configuration
with open('config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Initialize wandb
wandb.init(project=config['wandb']['project'], entity=config['wandb']['entity'])

# Load the data
train_dataset, test_dataset = load_and_tokenize_data(config)

# Load the teacher and student models
teacher_model = get_full_finetune_model()
student_model = get_student_model(config)

# Put the teacher in eval mode so dropout does not perturb its logits during distillation
teacher_model.eval()

# Define the training arguments for distillation
training_args = TrainingArguments(
    output_dir='./results_student',
    num_train_epochs=config['training']['num_epochs'],
    per_device_train_batch_size=config['training']['batch_size'],
    per_device_eval_batch_size=config['training']['batch_size'],
    evaluation_strategy='epoch',
    save_steps=10_000,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=10,
)

# Define the distillation trainer
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that recent
        # versions of transformers.Trainer pass to compute_loss
        # Forward pass of the teacher model (kept on the student's device, no gradients)
        teacher_model.to(next(model.parameters()).device)
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

        # Forward pass of the student model
        student_outputs = model(**inputs)

        # Distillation loss: KL divergence between the student's log-probabilities
        # and the teacher's probabilities
        loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(student_outputs.logits, dim=-1),
            torch.nn.functional.softmax(teacher_outputs.logits, dim=-1),
            reduction='batchmean'
        )
        return (loss, student_outputs) if return_outputs else loss

# Create the Trainer for distillation
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
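
# NOTE: measure_resources is not defined or imported in this script; it is assumed to
# come from a project utility module. The function below is a minimal, hypothetical
# stand-in (wall-clock time and peak GPU memory around trainer.train(), logged to
# wandb) so the script runs on its own; the project's real helper may differ.
import time

def measure_resources(trainer, run_name):
    """Train with `trainer` while recording wall-clock time and peak GPU memory."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    start = time.time()
    trainer.train()
    elapsed = time.time() - start
    peak_mem_mb = (
        torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0.0
    )
    wandb.log({
        f"{run_name}/train_time_seconds": elapsed,
        f"{run_name}/peak_gpu_memory_mb": peak_mem_mb,
    })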

# Measure resource usage and train the student model
measure_resources(trainer, "Distillation")