Trainer
The [Trainer] class provides an API for feature-complete training in PyTorch. It supports distributed training on multiple GPUs/TPUs and mixed precision on NVIDIA GPUs, on AMD GPUs, and through torch.amp in PyTorch. [Trainer] goes hand-in-hand with the [TrainingArguments] class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.
[Seq2SeqTrainer] and [Seq2SeqTrainingArguments] inherit from the [Trainer] and [TrainingArguments] classes and are adapted for training models on sequence-to-sequence tasks such as summarization or translation.
The [Trainer] class is optimized for 🤗 Transformers models and can have surprising behaviors when used with other models. When using it with your own model, make sure:
- your model always returns tuples or subclasses of [~utils.ModelOutput]
- your model can compute the loss if a labels argument is provided and that loss is returned as the first element of the tuple (if your model returns tuples)
- your model can accept multiple label arguments (use label_names in [TrainingArguments] to indicate their name to the [Trainer]) but none of them should be named "label"
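The requirements above can be illustrated with a minimal sketch of a custom PyTorch model. The class name TinyClassifier, the inputs argument name, and the layer sizes are hypothetical choices made for illustration only. What matters is the shape of the contract: the model always returns a tuple, accepts an optional labels argument (not label), and puts the loss first when labels are given.

```python
from typing import Optional

import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical toy model that follows the conventions listed above."""

    def __init__(self, num_features: int = 4, num_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs: torch.Tensor, labels: Optional[torch.Tensor] = None):
        logits = self.linear(inputs)
        if labels is None:
            # No labels: still return a tuple, without a loss.
            return (logits,)
        # Labels provided: compute the loss and return it as the first element.
        loss = self.loss_fn(logits, labels)
        return (loss, logits)


torch.manual_seed(0)
model = TinyClassifier()
features = torch.randn(3, 4)        # batch of 3 examples, 4 features each
targets = torch.tensor([0, 1, 1])   # one class index per example

outputs_without_labels = model(features)
outputs_with_labels = model(features, labels=targets)
```

A model written this way could then be passed as the model argument of [Trainer], assuming the dataset yields batches whose keys match the forward argument names (here inputs and labels).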
Trainer[[api-reference]]
[[autodoc]] Trainer - all
Seq2SeqTrainer
[[autodoc]] Seq2SeqTrainer - evaluate - predict
TrainingArguments
[[autodoc]] TrainingArguments - all
Seq2SeqTrainingArguments
[[autodoc]] Seq2SeqTrainingArguments - all