import os
import sys

import numpy as np
import torch


def mkdirs(paths):
    """Create one or several directories (including missing parents).

    Args:
        paths: A single path, or a list/tuple of paths; each one is passed
            to :func:`mkdir`. A ``str`` is always treated as a single path.
    """
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create ``path`` and any missing parents if nothing exists there yet.

    If ``path`` already exists (as a directory or anything else) this is a
    no-op.
    """
    if not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the exists() check and makedirs().
        os.makedirs(path, exist_ok=True)


def unnormalize(tens, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert a per-channel ``(x - mean) / std`` normalization.

    Args:
        tens: Tensor assumed to be shaped ``N x C x H x W``; ``mean`` and
            ``std`` are broadcast along dimension 1 (the channel axis).
        mean: Per-channel means. The defaults are the values commonly used
            for ImageNet-pretrained models.
        std: Per-channel standard deviations (same convention as ``mean``).

    Returns:
        A new tensor equal to ``tens * std + mean``.
    """
    # Build the statistics on the input's device so CUDA tensors work too;
    # dtype is left at the default so type promotion matches torch.Tensor().
    std_t = torch.as_tensor(std, device=tens.device)[None, :, None, None]
    mean_t = torch.as_tensor(mean, device=tens.device)[None, :, None, None]
    return tens * std_t + mean_t


class Logger(object):
    """Log stdout messages: echo everything to the terminal and a log file.

    Creating an instance replaces ``sys.stdout`` with the instance itself, so
    every subsequent ``print`` / stdout write goes through :meth:`write`.
    Call :meth:`close` to restore the previous stdout and close the file.
    """

    def __init__(self, outfile):
        # Keep the stream being replaced so messages still reach the terminal.
        self.terminal = sys.stdout
        # Append mode keeps earlier logs; UTF-8 so non-ASCII text is writable
        # regardless of the platform's locale encoding.
        self.log = open(outfile, 'a', encoding='utf-8')
        # Redirect stdout through this instance.
        sys.stdout = self

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        # Flushing the file too keeps the log current if the process dies.
        self.log.flush()

    def close(self):
        """Restore the original stdout (if still redirected here) and close the log."""
        if sys.stdout is self:
            sys.stdout = self.terminal
        self.log.close()


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without sufficient improvement
        self.best_score = None  # best score so far, where score = -val_loss
        self.early_stop = False
        # np.inf (not np.Inf, which was removed in NumPy 2.0).
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Update the stopping state with a new validation loss.

        A call counts as an improvement when ``val_loss <= best_loss - delta``
        (equality included); otherwise ``counter`` is incremented and
        ``early_stop`` is set once it reaches ``patience``.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Record a new best validation loss.

        Despite the name, this implementation does not persist ``model``: it
        only prints a message (when ``verbose``) and updates ``val_loss_min``.
        Override this method to actually save weights.
        """
        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        self.val_loss_min = val_loss


def printSet(set_str):
    """Print ``set_str`` as a banner framed by lines of ``=`` characters.

    The rules are three times the text length; the text is indented by its
    own length.
    """
    set_str = str(set_str)
    num = len(set_str)
    print('=' * num * 3)
    print(' ' * num + set_str)
    print('=' * num * 3)