Callbacks
Callbacks可以用来自定义PyTorch [Trainer]中训练循环行为的对象(此功能尚未在TensorFlow中实现),该对象可以检查训练循环状态(用于进度报告、在TensorBoard或其他ML平台上记录日志等),并做出决策(例如提前停止)。
Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]对象外,它们不能更改训练循环中的任何内容。对于需要更改训练循环的自定义,您应该继承[Trainer]并重载您需要的方法(有关示例,请参见trainer)。
默认情况下,TrainingArguments.report_to 设置为"all",然后[Trainer]将使用以下callbacks。
- [
DefaultFlowCallback],它处理默认的日志记录、保存和评估行为 - [
PrinterCallback] 或 [ProgressCallback],用于显示进度和打印日志(如果通过[TrainingArguments]停用tqdm,则使用第一个函数;否则使用第二个)。 - [
~integrations.TensorBoardCallback],如果TensorBoard可访问(通过PyTorch版本 >= 1.4 或者 tensorboardX)。 - [
~integrations.WandbCallback],如果安装了wandb。 - [
~integrations.CometCallback],如果安装了comet_ml。 - [
~integrations.MLflowCallback],如果安装了mlflow。 - [
~integrations.NeptuneCallback],如果安装了neptune。 - [
~integrations.AzureMLCallback],如果安装了azureml-sdk。 - [
~integrations.CodeCarbonCallback],如果安装了codecarbon。 - [
~integrations.ClearMLCallback],如果安装了clearml。 - [
~integrations.DagsHubCallback],如果安装了dagshub。 - [
~integrations.FlyteCallback],如果安装了flyte。 - [
~integrations.DVCLiveCallback],如果安装了dvclive。
如果安装了一个软件包,但您不希望使用相关的集成,您可以将 TrainingArguments.report_to 更改为仅包含您想要使用的集成的列表(例如 ["azure_ml", "wandb"])。
实现callbacks的主要类是[TrainerCallback]。它获取用于实例化[Trainer]的[TrainingArguments],可以通过[TrainerState]访问该Trainer的内部状态,并可以通过[TrainerControl]对训练循环执行一些操作。
可用的Callbacks
这里是库里可用[TrainerCallback]的列表:
[[autodoc]] integrations.CometCallback - setup
[[autodoc]] DefaultFlowCallback
[[autodoc]] PrinterCallback
[[autodoc]] ProgressCallback
[[autodoc]] EarlyStoppingCallback
[[autodoc]] integrations.TensorBoardCallback
[[autodoc]] integrations.WandbCallback - setup
[[autodoc]] integrations.MLflowCallback - setup
[[autodoc]] integrations.AzureMLCallback
[[autodoc]] integrations.CodeCarbonCallback
[[autodoc]] integrations.NeptuneCallback
[[autodoc]] integrations.ClearMLCallback
[[autodoc]] integrations.DagsHubCallback
[[autodoc]] integrations.FlyteCallback
[[autodoc]] integrations.DVCLiveCallback - setup
TrainerCallback
[[autodoc]] TrainerCallback
以下是如何使用PyTorch注册自定义callback的示例:
[Trainer]:
class MyCallback(TrainerCallback):
"A callback that prints a message at the beginning of training"
def on_train_begin(self, args, state, control, **kwargs):
print("Starting training")
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[MyCallback], # We can either pass the callback class this way or an instance of it (MyCallback())
)
注册callback的另一种方式是调用 trainer.add_callback(),如下所示:
trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())
TrainerState
[[autodoc]] TrainerState
TrainerControl
[[autodoc]] TrainerControl