""" Early stopping callback for RF-DETR training """ from logging import getLogger logger = getLogger(__name__) class EarlyStoppingCallback: """ Early stopping callback that monitors mAP and stops training if no improvement over a threshold is observed for a specified number of epochs. Args: patience (int): Number of epochs with no improvement to wait before stopping min_delta (float): Minimum change in mAP to qualify as improvement use_ema (bool): Whether to use EMA model metrics for early stopping verbose (bool): Whether to print early stopping messages """ def __init__(self, model, patience=5, min_delta=0.001, use_ema=False, verbose=True): self.patience = patience self.min_delta = min_delta self.use_ema = use_ema self.verbose = verbose self.best_map = 0.0 self.counter = 0 self.model = model def update(self, log_stats): """Update early stopping state based on epoch validation metrics""" regular_map = None ema_map = None if 'test_coco_eval_bbox' in log_stats: regular_map = log_stats['test_coco_eval_bbox'][0] if 'ema_test_coco_eval_bbox' in log_stats: ema_map = log_stats['ema_test_coco_eval_bbox'][0] current_map = None if regular_map is not None and ema_map is not None: if self.use_ema: current_map = ema_map metric_source = "EMA" else: current_map = max(regular_map, ema_map) metric_source = "max(regular, EMA)" elif ema_map is not None: current_map = ema_map metric_source = "EMA" elif regular_map is not None: current_map = regular_map metric_source = "regular" else: if self.verbose: raise ValueError("No valid mAP metric found!") return if self.verbose: print(f"Early stopping: Current mAP ({metric_source}): {current_map:.4f}, Best: {self.best_map:.4f}, Diff: {current_map - self.best_map:.4f}, Min delta: {self.min_delta}") if current_map > self.best_map + self.min_delta: self.best_map = current_map self.counter = 0 logger.info(f"Early stopping: mAP improved to {current_map:.4f} using {metric_source} metric") else: self.counter += 1 if self.verbose: print(f"Early stopping: No improvement in mAP for {self.counter} epochs (best: {self.best_map:.4f}, current: {current_map:.4f})") if self.counter >= self.patience: print(f"Early stopping triggered: No improvement above {self.min_delta} threshold for {self.patience} epochs") if self.model: self.model.request_early_stop()