# gen/models/student_model.py
# Uploaded by jaothan ("Upload 24 files"), commit 6f20934 (verified).
import inspect

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


def train_student_model(
    train_dataset,
    val_dataset,
    *,
    model_name="distilbert-base-uncased",
    output_dir="./results",
    num_train_epochs=3,
    batch_size=16,
):
    """Fine-tune a sequence-classification "student" model and return it.

    Loads ``model_name`` via ``AutoModelForSequenceClassification``, trains it
    on ``train_dataset`` with a Hugging Face ``Trainer``, and evaluates on
    ``val_dataset`` at the end of every epoch.

    Args:
        train_dataset: Training split, passed as ``Trainer(train_dataset=...)``.
            NOTE(review): no tokenizer or data collator is given to the
            Trainer, so this presumably must already be tokenized — confirm
            against the caller.
        val_dataset: Evaluation split, passed as ``Trainer(eval_dataset=...)``.
        model_name: Hub id or local path of the checkpoint to fine-tune.
        output_dir: Directory the Trainer writes checkpoints to.
        num_train_epochs: Number of training epochs.
        batch_size: Per-device batch size used for both training and evaluation.

    Returns:
        The fine-tuned model (the same object the Trainer trained in place).
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # ``evaluation_strategy`` was deprecated in transformers 4.41 in favour of
    # ``eval_strategy`` and later removed; passing the old name to a current
    # release raises TypeError. Use whichever keyword this version accepts.
    accepted = inspect.signature(TrainingArguments).parameters
    strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_total_limit=1,  # cap the number of checkpoints kept on disk at one
        **{strategy_key: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()
    return model