import numpy as np

from utils1.trainer import Trainer


class EarlyStopping:
    """Stop training early when a validation score stops improving.

    The monitored ``score`` is treated as "higher is better" (e.g. validation
    accuracy). Each call compares the new score with the best score seen so
    far:

    * If it improves on the best by at least ``delta``, it becomes the new
      best, a checkpoint is saved via ``trainer.save_networks("best")`` and
      the non-improvement counter is reset.
    * Otherwise the counter is incremented. Once it reaches ``patience``,
      :attr:`early_stop` is set to True so the training loop can stop.
    """

    def __init__(self, patience: int = 1, verbose: bool = False, delta: float = 0):
        """
        Args:
            patience (int): Number of consecutive calls without improvement
                tolerated before ``early_stop`` is set. Default: 1
            verbose (bool): If True, prints a message each time the score
                improves and a checkpoint is saved. Default: False
            delta (float): Minimum increase over the best score required for a
                new score to qualify as an improvement. Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        # Consecutive calls that did not improve on best_score.
        self.counter = 0
        # Best score seen so far; None until the first call.
        self.best_score = None
        # Becomes True once counter reaches patience; read by the caller.
        self.early_stop = False
        # Score of the most recently saved checkpoint, shown in the verbose
        # message. Uses ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0.
        self.score_max = -np.inf
        self.delta = delta

    def __call__(self, score: float, trainer: Trainer):
        """Record a new validation score and update the early-stopping state.

        Args:
            score (float): Latest validation score (higher is better).
            trainer (Trainer): Object whose ``save_networks("best")`` method is
                called whenever ``score`` becomes the new best.
        """
        if self.best_score is None:
            # First observation always becomes the baseline and is saved.
            self.best_score = score
            self.save_checkpoint(score, trainer)
        elif score < self.best_score + self.delta:
            # Not better than the best by at least ``delta``: count it as a
            # non-improvement and leave best_score (and the checkpoint) alone.
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, trainer)
            self.counter = 0

    def save_checkpoint(self, score: float, trainer: Trainer):
        """Save the current networks as the "best" checkpoint.

        Args:
            score (float): The new best score; stored in ``score_max`` after saving.
            trainer (Trainer): Object whose ``save_networks("best")`` is called.
        """
        if self.verbose:
            print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
        trainer.save_networks("best")
        self.score_max = score