# Spaces: Sleeping
# File size: 2,147 Bytes
# Page gutter residue: fa64206 (likely the revision id of this file)
import torch
import wandb
import yaml
from transformers import Trainer, TrainingArguments
from data.datasets import load_and_tokenize_data
from models.full_finetune_model import get_full_finetune_model
from models.student_model import get_student_model
# Load the run configuration.
# The encoding is given explicitly so that decoding does not depend on the
# platform's default locale encoding.
with open('config/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# Initialise wandb with the project / entity named in the configuration.
wandb_settings = config['wandb']
wandb.init(
    project=wandb_settings['project'],
    entity=wandb_settings['entity'],
)

# Load the train and test datasets.
train_dataset, test_dataset = load_and_tokenize_data(config)

# Load the teacher model and the student model.
teacher_model = get_full_finetune_model()
student_model = get_student_model(config)
# Training arguments for the distillation run.
# Epoch count and batch size come from the 'training' section of the config;
# the same batch size is used for training and evaluation.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the version this project pins.
training_cfg = config['training']
training_args = TrainingArguments(
    output_dir='./results_student',
    num_train_epochs=training_cfg['num_epochs'],
    per_device_train_batch_size=training_cfg['batch_size'],
    per_device_eval_batch_size=training_cfg['batch_size'],
    evaluation_strategy='epoch',
    save_steps=10_000,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=10,
)
# Define the distillation trainer.
class DistillationTrainer(Trainer):
    """Trainer whose loss is the KL divergence between teacher and student.

    The teacher is the module-level ``teacher_model``; the model handed to
    the trainer is the student being optimised.
    """

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        """Return the distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch of keyword arguments, forwarded unchanged to both
                the teacher and the student.
            return_outputs: When true, also return the student's outputs.
            num_items_in_batch: Accepted because newer ``Trainer`` versions
                pass it as a keyword; it is not used by this loss.

        Returns:
            The loss tensor, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        # The teacher only provides targets: keep it in eval mode so layers
        # such as dropout do not perturb its predictions.
        teacher_model.eval()

        # The teacher is a module-level object that the Trainer does not
        # manage, so align it with the training device when it is elsewhere.
        teacher_param = next(teacher_model.parameters(), None)
        if teacher_param is not None and teacher_param.device != self.args.device:
            teacher_model.to(self.args.device)

        # Forward pass of the teacher model (no gradients needed).
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

        # Forward pass of the student model.
        student_outputs = model(**inputs)

        # Distillation loss: kl_div expects log-probabilities as input and
        # probabilities as target.
        student_log_probs = torch.nn.functional.log_softmax(
            student_outputs.logits, dim=-1
        )
        teacher_probs = torch.nn.functional.softmax(
            teacher_outputs.logits, dim=-1
        )
        loss = torch.nn.functional.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean',
        )
        return (loss, student_outputs) if return_outputs else loss
# Create the Trainer for the distillation run: the student is the model
# being trained, evaluated on the test split.
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Measure resources and train the student model (per the original comment).
# NOTE(review): measure_resources is neither imported nor defined in the
# visible part of this file — confirm where it is supposed to come from.
measure_resources(trainer, "Distillation")
|