from typing import List from utils.early_stopping import EarlyStoppingCallback from utils.tensorboard import TensorboardLoggerCallback def setup_callbacks(ckpt_path: str, log_dir: str, early_patience: int, early_monitor: str = "val_loss", early_mode: str = "min", use_tensorboard: bool = False, **kwargs) -> List[object]: delta = kwargs.get("delta", 0) verbose = kwargs.get("verbose", True) restore_best_weights = kwargs.get("restore_best_weights", False) save_final_model = kwargs.get("save_final_model", True) es_call = EarlyStoppingCallback(patience=early_patience, delta=delta, checkpoint_path=ckpt_path, verbose=verbose, restore_best_weights=restore_best_weights, monitor=early_monitor, mode=early_mode, save_final_model=save_final_model ) log_histograms = kwargs.get("log_histograms", False) log_images_every = kwargs.get("log_images_every", 10) tb_port = kwargs.get("tb_port", 6006) tb_browser = kwargs.get("tb_browser", False) tb_logger = TensorboardLoggerCallback(log_dir=log_dir, log_histograms=log_histograms, launch_tb=use_tensorboard, tb_port=tb_port, open_tb_in_browser=tb_browser ) return [es_call, tb_logger]