Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
def mkdirs(paths):
    """Create one directory or several.

    Args:
        paths: A single path, or a list/tuple of paths. Each one is handed to
            ``mkdir``, which creates missing parent directories and does
            nothing for paths that already exist.
    """
    # A str is never a list/tuple, so no separate str exclusion is needed;
    # tuples are accepted too (previously they reached os.path.exists and
    # raised TypeError).
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
def mkdir(path):
    """Create *path*, including any missing parent directories.

    Does nothing if something already exists at *path*. Creation is
    attempted directly instead of after an ``os.path.exists`` check, so two
    processes creating the same directory at the same time cannot race each
    other into a ``FileExistsError``.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present -- possibly created by another process just now.
        pass
def unnormalize(tens, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization, returning ``tens * std + mean``.

    Args:
        tens: Tensor assumed to be of shape NxCxHxW; ``mean`` and ``std``
            are broadcast along dimension 1 (channels).
        mean: Per-channel means, one value per channel. (The defaults look
            like the usual ImageNet statistics -- confirm they match the
            normalization applied upstream.)
        std: Per-channel standard deviations, one value per channel.

    Returns:
        A new tensor ``tens * std + mean`` on the same device as ``tens``.
    """
    # Build the statistics on tens' device: CPU stats cannot be combined with
    # a CUDA input. The dtype stays at the default floating dtype (what
    # torch.Tensor(...) produced), so type promotion is unchanged.
    dtype = torch.get_default_dtype()
    std_t = torch.as_tensor(std, dtype=dtype, device=tens.device)[None, :, None, None]
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tens.device)[None, :, None, None]
    return tens * std_t + mean_t
# Redirect sys.stdout output to a given file while still showing it in the terminal.
class Logger(object):
    """Tee everything written to ``sys.stdout`` into a log file as well.

    Constructing a Logger replaces ``sys.stdout`` with the instance, so every
    later ``print`` goes both to the original stream and to ``outfile``
    (opened in append mode, UTF-8). Call ``close()`` to restore the original
    stream and close the file.
    """

    def __init__(self, outfile):
        self.terminal = sys.stdout  # the stream being wrapped
        # Explicit UTF-8: the platform default encoding (e.g. a legacy Windows
        # code page) may be unable to encode arbitrary text.
        self.log = open(outfile, 'a', encoding='utf-8')
        # From here on, print() and other stdout writes go through write().
        sys.stdout = self

    def write(self, message):
        """Write *message* to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        # Without this the log file's buffer is only written out when it
        # fills up, so recent output could be lost if the process dies.
        self.log.flush()

    def close(self):
        """Restore the wrapped stream as ``sys.stdout`` and close the log file."""
        if sys.stdout is self:
            sys.stdout = self.terminal
        self.log.close()

    def __getattr__(self, name):
        # Delegate anything not defined here (isatty, fileno, encoding, ...)
        # to the wrapped stream so code that inspects sys.stdout keeps
        # working. 'terminal' itself is excluded to avoid infinite recursion
        # when it has not been set yet.
        if name == 'terminal':
            raise AttributeError(name)
        return getattr(self.terminal, name)
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # highest score seen so far, where score = -val_loss
        self.early_stop = False  # becomes True once counter reaches patience
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record one validation loss and update the early-stopping state.

        A call counts as an improvement when it is the first non-NaN loss
        seen, or when ``-val_loss >= best_score + delta``. On improvement the
        best score is updated, ``save_checkpoint`` is called and the counter
        is reset. Otherwise the counter is incremented, and ``early_stop``
        is set once it reaches ``patience``.

        Args:
            val_loss: Validation loss for this evaluation (lower is better).
            model: Passed through to ``save_checkpoint``.
        """
        score = -val_loss
        if score != score:
            # NaN is the only value unequal to itself. Comparisons with NaN
            # are always False, so without this check a NaN loss would be
            # treated as an improvement: patience reset and checkpoint taken.
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Record *val_loss* as the new best validation loss.

        Despite its name, this method does not persist ``model`` (the argument
        is unused here). It only prints a message when ``verbose`` is set and
        updates ``val_loss_min``.
        NOTE(review): add real saving here, or override in a subclass, if
        on-disk checkpoints are expected.
        """
        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        self.val_loss_min = val_loss
def printSet(set_str):
    """Print ``str(set_str)`` as a banner.

    The text is indented by its own length and framed above and below by a
    rule of ``=`` characters three times that length.
    """
    text = str(set_str)
    width = len(text)
    rule = '=' * (3 * width)
    banner_lines = [rule, ' ' * width + text, rule]
    print('\n'.join(banner_lines))