Spaces:
Build error
Build error
import math
import os

import numpy as np
import torch
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Each call compares a new validation loss with the best one seen so far.
    On an improvement the model's ``state_dict`` is checkpointed to
    ``<save_path>/best_network.pth``. Otherwise a counter is incremented, and
    ``early_stop`` is set to True once it reaches ``patience``.
    """

    def __init__(self, save_path, patience=7, verbose=False, delta=0):
        """
        Args:
            save_path (str): Directory the best checkpoint is written to. It is
                created on the first save if it does not already exist.
            patience (int): How many consecutive calls without improvement are
                tolerated before ``early_stop`` is set. Default: 7
            verbose (bool): If True, prints a message for each validation loss
                improvement. Default: False
            delta (float): Minimum decrease in validation loss that counts as
                an improvement. Default: 0
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # negated best validation loss (higher is better)
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record one validation result and checkpoint ``model`` if it improved.

        Args:
            val_loss: Scalar validation loss (anything ``float()`` accepts,
                e.g. a Python number or a single-element tensor).
            model: Object whose ``state_dict()`` is saved on improvement.
        """
        # NaN compares False against everything. Without this guard, a NaN loss
        # would take the "improved" branch, overwrite the best checkpoint, and
        # poison best_score so that every later epoch also counts as improved.
        if math.isnan(float(val_loss)):
            self._record_no_improvement()
            return

        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self._record_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def _record_no_improvement(self):
        """Bump the patience counter and flag early stop once it is exhausted."""
        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ... best_network.pth ...')
        # Create the target directory lazily, so a missing folder doesn't crash
        # training at the first improvement. An empty save_path means "current
        # directory", which needs no creation (os.makedirs('') would raise).
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
        # This file always holds the parameters of the best model seen so far.
        path = os.path.join(self.save_path, 'best_network.pth')
        self.save_networks(path, model)
        self.val_loss_min = val_loss

    def save_networks(self, save_path, model):
        """Serialize ``model.state_dict()`` under the ``'model'`` key to ``save_path``.

        Args:
            save_path (str): Full file path of the checkpoint to write.
            model: Object exposing ``state_dict()``.
        """
        state_dict = {
            'model': model.state_dict(),
        }
        torch.save(state_dict, save_path)