File size: 990 Bytes
cc0720f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

from lightning.pytorch.callbacks import (
    TQDMProgressBar,
    ModelCheckpoint,
    LearningRateMonitor,
)
from src.layers.utils_training import FreezeClustering

def get_callbacks(args):
    """Build the Lightning callbacks used for a training run.

    Args:
        args: Run configuration. Reads ``args.model_prefix`` (passed as the
            checkpoint ``dirpath``) and ``args.freeze_clustering`` (when
            truthy, a ``FreezeClustering`` callback is appended).

    Returns:
        list: ``[TQDMProgressBar, ModelCheckpoint, LearningRateMonitor]``,
        followed by ``FreezeClustering()`` when ``args.freeze_clustering`` is
        truthy.
    """
    # Write a weights-only checkpoint every 500 training steps.
    # save_top_k=-1 tells Lightning to keep every checkpoint rather than
    # pruning down to the best k.
    # NOTE(review): the original comment suggests dirpath could instead be
    # configured on the Trainer for version control — confirm intent.
    ckpt = ModelCheckpoint(
        dirpath=args.model_prefix,
        filename="_{epoch}_{step}",
        every_n_train_steps=500,
        save_top_k=-1,
        save_weights_only=True,
    )
    lr_logger = LearningRateMonitor(logging_interval="epoch")
    progress = TQDMProgressBar(refresh_rate=10)

    # Optional callback, only added when the run requests it.
    extra = [FreezeClustering()] if args.freeze_clustering else []
    return [progress, ckpt, lr_logger, *extra]

def get_callbacks_eval(args):
    """Build the Lightning callbacks used for an evaluation run.

    Args:
        args: Run configuration. Not read here; accepted so the signature
            matches ``get_callbacks``.

    Returns:
        list: A single ``TQDMProgressBar`` with ``refresh_rate=1``.
    """
    return [TQDMProgressBar(refresh_rate=1)]