Spaces:
Running
Running
| import os | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
def save_img_and_npy(path, matrix, *, save_npy=False):
    """Save ``matrix`` as an image at ``path + ".png"`` and optionally as ``.npy``.

    The image is written with ``plt.imsave(..., origin="lower")``, so the first
    row of ``matrix`` ends up at the bottom of the picture.

    Despite the function name, the raw array is only written to
    ``path + ".npy"`` when ``save_npy`` is True. The default keeps the
    historical PNG-only behavior so existing callers are unaffected.

    Args:
        path: Output path WITHOUT extension; ".png" / ".npy" are appended.
        matrix: Array-like accepted by ``matplotlib.pyplot.imsave``.
        save_npy: When True, also dump ``matrix`` with ``np.save``.
    """
    plt.imsave(path + ".png", matrix, origin="lower")
    if save_npy:
        np.save(path + ".npy", np.asarray(matrix))
def save_checkpoint(state, state_dict_only, path, target):
    """Persist a training checkpoint under ``path``.

    The complete ``state`` object is ALWAYS written to
    ``<path>/<target>.chkpnt``. When ``state_dict_only`` is truthy, the
    weights alone (``state["state_dict"]``) are additionally written to
    ``<path>/<target>.pth``. The flag adds a second file; it does not replace
    the full checkpoint.
    """
    def _destination(extension):
        # Each file sits next to the other, named after ``target``.
        return os.path.join(path, target + extension)

    torch.save(state, _destination(".chkpnt"))
    if not state_dict_only:
        return
    # Weights-only copy, convenient for inference-time loading.
    weights = state["state_dict"]
    torch.save(weights, _destination(".pth"))
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: The most recent value passed to :meth:`update`.
        sum: Weighted sum of all values (each value multiplied by its ``n``).
        count: Total weight, i.e. the sum of all ``n`` passed to :meth:`update`.
        avg: ``sum / count``; stays at its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: update(x, n=0) on an empty meter would otherwise
        # raise ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count

    def __repr__(self):
        return (
            f"{type(self).__name__}(val={self.val!r}, avg={self.avg!r}, "
            f"sum={self.sum!r}, count={self.count!r})"
        )
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Call :meth:`step` once per epoch with the latest metric value; it returns
    True when training should stop.

    Args:
        mode: "min" if smaller metric values are better, "max" if larger are.
        min_delta: Margin by which a value must beat the current best to count
            as an improvement.
        patience: Number of consecutive non-improving steps tolerated before
            stopping. 0 disables patience-based stopping; a NaN metric still
            triggers a stop.

    Raises:
        ValueError: If ``mode`` is neither "min" nor "max".
    """

    def __init__(self, mode="min", min_delta=0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None  # best metric seen so far; None until the first step
        self.num_bad_epochs = 0  # consecutive steps without improvement
        self.is_better = None
        self._init_is_better(mode, min_delta)
        if patience == 0:
            # Treat every value as an improvement so the bad-epoch counter never grows.
            self.is_better = lambda a, b: True

    def step(self, metrics):
        """Feed the latest metric value; return True if training should stop."""
        # A NaN metric means training diverged. Check it before seeding
        # ``best`` so NaN never becomes the reference value (every comparison
        # against NaN is False, which would silently stall improvement checks).
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        # patience == 0 means "disabled": without this guard, 0 >= 0 would
        # request a stop on the second step.
        return self.patience > 0 and self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta):
        """Install the ``is_better(a, best)`` comparator for ``mode``."""
        if mode not in {"min", "max"}:
            # f-string so a non-str ``mode`` still yields the ValueError.
            raise ValueError(f"mode {mode} is unknown!")
        if mode == "min":
            self.is_better = lambda a, best: a < best - min_delta
        if mode == "max":
            self.is_better = lambda a, best: a > best + min_delta