DrDavis's picture
Upload folder using huggingface_hub
17c6d62 verified

Callbacks

Callbacks可以用来自定义PyTorch [Trainer]中训练循环行为的对象(此功能尚未在TensorFlow中实现),该对象可以检查训练循环状态(用于进度报告、在TensorBoard或其他ML平台上记录日志等),并做出决策(例如提前停止)。

Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]对象外,它们不能更改训练循环中的任何内容。对于需要更改训练循环的自定义,您应该继承[Trainer]并重载您需要的方法(有关示例,请参见trainer)。

默认情况下,TrainingArguments.report_to 设置为"all",然后[Trainer]将使用以下callbacks。

  • [DefaultFlowCallback],它处理默认的日志记录、保存和评估行为
  • [PrinterCallback] 或 [ProgressCallback],用于显示进度和打印日志(如果通过[TrainingArguments]停用tqdm,则使用第一个函数;否则使用第二个)。
  • [~integrations.TensorBoardCallback],如果TensorBoard可访问(通过PyTorch版本 >= 1.4 或者 tensorboardX)。
  • [~integrations.WandbCallback],如果安装了wandb
  • [~integrations.CometCallback],如果安装了comet_ml
  • [~integrations.MLflowCallback],如果安装了mlflow
  • [~integrations.NeptuneCallback],如果安装了neptune
  • [~integrations.AzureMLCallback],如果安装了azureml-sdk
  • [~integrations.CodeCarbonCallback],如果安装了codecarbon
  • [~integrations.ClearMLCallback],如果安装了clearml
  • [~integrations.DagsHubCallback],如果安装了dagshub
  • [~integrations.FlyteCallback],如果安装了flyte
  • [~integrations.DVCLiveCallback],如果安装了dvclive

如果安装了一个软件包,但您不希望使用相关的集成,您可以将 TrainingArguments.report_to 更改为仅包含您想要使用的集成的列表(例如 ["azure_ml", "wandb"])。

实现callbacks的主要类是[TrainerCallback]。它获取用于实例化[Trainer]的[TrainingArguments],可以通过[TrainerState]访问该Trainer的内部状态,并可以通过[TrainerControl]对训练循环执行一些操作。

可用的Callbacks

这里是库里可用[TrainerCallback]的列表:

[[autodoc]] integrations.CometCallback - setup

[[autodoc]] DefaultFlowCallback

[[autodoc]] PrinterCallback

[[autodoc]] ProgressCallback

[[autodoc]] EarlyStoppingCallback

[[autodoc]] integrations.TensorBoardCallback

[[autodoc]] integrations.WandbCallback - setup

[[autodoc]] integrations.MLflowCallback - setup

[[autodoc]] integrations.AzureMLCallback

[[autodoc]] integrations.CodeCarbonCallback

[[autodoc]] integrations.NeptuneCallback

[[autodoc]] integrations.ClearMLCallback

[[autodoc]] integrations.DagsHubCallback

[[autodoc]] integrations.FlyteCallback

[[autodoc]] integrations.DVCLiveCallback - setup

TrainerCallback

[[autodoc]] TrainerCallback

以下是如何使用PyTorch注册自定义callback的示例:

[Trainer]:

class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback],  # We can either pass the callback class this way or an instance of it (MyCallback())
)

注册callback的另一种方式是调用 trainer.add_callback(),如下所示:

trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())

TrainerState

[[autodoc]] TrainerState

TrainerControl

[[autodoc]] TrainerControl