HitPF_demo / src /utils /callbacks.py
github-actions[bot]
Sync from GitHub f6dbbfb
cc0720f
from lightning.pytorch.callbacks import (
TQDMProgressBar,
ModelCheckpoint,
LearningRateMonitor,
)
from src.layers.utils_training import FreezeClustering
def get_callbacks(args, *, checkpoint_every_n_steps: int = 500, progress_refresh_rate: int = 10):
    """Build the list of Lightning callbacks used for training.

    Args:
        args: Run configuration. Reads ``args.model_prefix`` (checkpoint
            directory) and ``args.freeze_clustering`` (whether to append a
            ``FreezeClustering`` callback).
        checkpoint_every_n_steps: Save a checkpoint every this many training
            steps. Defaults to 500.
        progress_refresh_rate: Refresh rate passed to ``TQDMProgressBar``.
            Defaults to 10.

    Returns:
        A list containing, in order:
        - a progress bar,
        - a ``ModelCheckpoint``,
        - a ``LearningRateMonitor``,
        - optionally a ``FreezeClustering`` callback.
    """
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.model_prefix,
        # Step and epoch in the filename keep step-based checkpoints distinct.
        filename="_{epoch}_{step}",
        every_n_train_steps=checkpoint_every_n_steps,
        # -1 keeps every checkpoint instead of pruning to the top-k by a monitored metric.
        save_top_k=-1,
        # Store only model weights, not the full trainer/optimizer state.
        save_weights_only=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    callbacks = [
        TQDMProgressBar(refresh_rate=progress_refresh_rate),
        checkpoint_callback,
        lr_monitor,
    ]
    if args.freeze_clustering:
        callbacks.append(FreezeClustering())
    return callbacks
def get_callbacks_eval(args):
    """Return the callbacks used for evaluation runs.

    ``args`` is accepted to mirror ``get_callbacks`` but is not read here.
    The result is a single-element list holding a ``TQDMProgressBar`` that
    refreshes on every step (``refresh_rate=1``).
    """
    return [TQDMProgressBar(refresh_rate=1)]