์ฝ๋ฐฑ [[callbacks]]
์ฝ๋ฐฑ์ PyTorch [Trainer]์ ๋ฐ๋ณต ํ์ต ๋์์ ์ฌ์ฉ์ ์ ์ํ ์ ์๋ ๊ฐ์ฒด์
๋๋ค
(์ด ๊ธฐ๋ฅ์ TensorFlow์์๋ ์์ง ๊ตฌํ๋์ง ์์์ต๋๋ค). ์ฝ๋ฐฑ์ ๋ฐ๋ณต ํ์ต์ ์ํ๋ฅผ
๊ฒ์ฌํ์ฌ (์งํ ์ํฉ ๋ณด๊ณ , TensorBoard ๋๋ ๊ธฐํ ๋จธ์ ๋ฌ๋ ํ๋ซํผ์ ๋ก๊ทธ ๋จ๊ธฐ๊ธฐ ๋ฑ)
๊ฒฐ์ (์: ์กฐ๊ธฐ ์ข
๋ฃ)์ ๋ด๋ฆด ์ ์์ต๋๋ค.
์ฝ๋ฐฑ์ [TrainerControl] ๊ฐ์ฒด๋ฅผ ๋ฐํํ๋ ๊ฒ ์ธ์๋ ๋ฐ๋ณต ํ์ต์์ ์ด๋ค ๊ฒ๋ ๋ณ๊ฒฝํ ์ ์๋
"์ฝ๊ธฐ ์ ์ฉ" ์ฝ๋ ์กฐ๊ฐ์
๋๋ค. ๋ฐ๋ณต ํ์ต์ ๋ณ๊ฒฝ์ด ํ์ํ ์ฌ์ฉ์ ์ ์ ์์
์ด ํ์ํ ๊ฒฝ์ฐ,
[Trainer]๋ฅผ ์๋ธํด๋์ค๋ก ๋ง๋ค์ด ํ์ํ ๋ฉ์๋๋ค์ ์ค๋ฒ๋ผ์ด๋ํด์ผ ํฉ๋๋ค (์์๋ trainer๋ฅผ ์ฐธ์กฐํ์ธ์).
๊ธฐ๋ณธ์ ์ผ๋ก TrainingArguments.report_to๋ "all"๋ก ์ค์ ๋์ด ์์ผ๋ฏ๋ก, [Trainer]๋ ๋ค์ ์ฝ๋ฐฑ์ ์ฌ์ฉํฉ๋๋ค.
- [
DefaultFlowCallback]๋ ๋ก๊ทธ, ์ ์ฅ, ํ๊ฐ์ ๋ํ ๊ธฐ๋ณธ ๋์์ ์ฒ๋ฆฌํฉ๋๋ค. - [
PrinterCallback] ๋๋ [ProgressCallback]๋ ์งํ ์ํฉ์ ํ์ํ๊ณ ๋ก๊ทธ๋ฅผ ์ถ๋ ฅํฉ๋๋ค ([TrainingArguments]๋ฅผ ํตํด tqdm์ ๋นํ์ฑํํ๋ฉด ์ฒซ ๋ฒ์งธ ์ฝ๋ฐฑ์ด ์ฌ์ฉ๋๊ณ , ๊ทธ๋ ์ง ์์ผ๋ฉด ๋ ๋ฒ์งธ๊ฐ ์ฌ์ฉ๋ฉ๋๋ค). - [
~integrations.TensorBoardCallback]๋ TensorBoard๊ฐ (PyTorch >= 1.4 ๋๋ tensorboardX๋ฅผ ํตํด) ์ ๊ทผ ๊ฐ๋ฅํ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.WandbCallback]๋ wandb๊ฐ ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.CometCallback]๋ comet_ml์ด ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.MLflowCallback]๋ mlflow๊ฐ ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.NeptuneCallback]๋ neptune์ด ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.AzureMLCallback]๋ azureml-sdk๊ฐ ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.CodeCarbonCallback]๋ codecarbon์ด ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.ClearMLCallback]๋ clearml์ด ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.DagsHubCallback]๋ dagshub์ด ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.FlyteCallback]๋ flyte๊ฐ ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค. - [
~integrations.DVCLiveCallback]๋ dvclive๊ฐ ์ค์น๋์ด ์์ผ๋ฉด ์ฌ์ฉ๋ฉ๋๋ค.
ํจํค์ง๊ฐ ์ค์น๋์ด ์์ง๋ง ํด๋น ํตํฉ ๊ธฐ๋ฅ์ ์ฌ์ฉํ๊ณ ์ถ์ง ์๋ค๋ฉด, TrainingArguments.report_to๋ฅผ ์ฌ์ฉํ๊ณ ์ ํ๋ ํตํฉ ๊ธฐ๋ฅ ๋ชฉ๋ก์ผ๋ก ๋ณ๊ฒฝํ ์ ์์ต๋๋ค (์: ["azure_ml", "wandb"]).
์ฝ๋ฐฑ์ ๊ตฌํํ๋ ์ฃผ์ ํด๋์ค๋ [TrainerCallback]์
๋๋ค. ์ด ํด๋์ค๋ [Trainer]๋ฅผ
์ธ์คํด์คํํ๋ ๋ฐ ์ฌ์ฉ๋ [TrainingArguments]๋ฅผ ๊ฐ์ ธ์ค๊ณ , ํด๋น Trainer์ ๋ด๋ถ ์ํ๋ฅผ
[TrainerState]๋ฅผ ํตํด ์ ๊ทผํ ์ ์์ผ๋ฉฐ, [TrainerControl]์ ํตํด ๋ฐ๋ณต ํ์ต์์ ์ผ๋ถ
์์
์ ์ํํ ์ ์์ต๋๋ค.
์ฌ์ฉ ๊ฐ๋ฅํ ์ฝ๋ฐฑ [[available-callbacks]]
๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ์ฌ์ฉ ๊ฐ๋ฅํ [TrainerCallback] ๋ชฉ๋ก์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
[[autodoc]] integrations.CometCallback - setup
[[autodoc]] DefaultFlowCallback
[[autodoc]] PrinterCallback
[[autodoc]] ProgressCallback
[[autodoc]] EarlyStoppingCallback
[[autodoc]] integrations.TensorBoardCallback
[[autodoc]] integrations.WandbCallback - setup
[[autodoc]] integrations.MLflowCallback - setup
[[autodoc]] integrations.AzureMLCallback
[[autodoc]] integrations.CodeCarbonCallback
[[autodoc]] integrations.NeptuneCallback
[[autodoc]] integrations.ClearMLCallback
[[autodoc]] integrations.DagsHubCallback
[[autodoc]] integrations.FlyteCallback
[[autodoc]] integrations.DVCLiveCallback - setup
TrainerCallback [[trainercallback]]
[[autodoc]] TrainerCallback
์ฌ๊ธฐ PyTorch [Trainer]์ ํจ๊ป ์ฌ์ฉ์ ์ ์ ์ฝ๋ฐฑ์ ๋ฑ๋กํ๋ ์์๊ฐ ์์ต๋๋ค:
class MyCallback(TrainerCallback):
"A callback that prints a message at the beginning of training"
def on_train_begin(self, args, state, control, **kwargs):
print("Starting training")
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[MyCallback], # ์ฐ๋ฆฌ๋ ์ฝ๋ฐฑ ํด๋์ค๋ฅผ ์ด ๋ฐฉ์์ผ๋ก ์ ๋ฌํ๊ฑฐ๋ ๊ทธ๊ฒ์ ์ธ์คํด์ค(MyCallback())๋ฅผ ์ ๋ฌํ ์ ์์ต๋๋ค
)
๋ ๋ค๋ฅธ ์ฝ๋ฐฑ์ ๋ฑ๋กํ๋ ๋ฐฉ๋ฒ์ trainer.add_callback()์ ํธ์ถํ๋ ๊ฒ์
๋๋ค:
trainer = Trainer(...)
trainer.add_callback(MyCallback)
# ๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก๋ ์ฝ๋ฐฑ ํด๋์ค์ ์ธ์คํด์ค๋ฅผ ์ ๋ฌํ ์ ์์ต๋๋ค
trainer.add_callback(MyCallback())
TrainerState [[trainerstate]]
[[autodoc]] TrainerState
TrainerControl [[trainercontrol]]
[[autodoc]] TrainerControl