import itertools
import logging
import os

try:
    from tensorboardX import SummaryWriter
except ImportError as _err:
    # Plain-text logging (setup_logging) must stay usable without tensorboardX;
    # the original error is kept so TBLogger can re-raise it with context.
    SummaryWriter = None
    _TB_IMPORT_ERROR = _err
else:
    _TB_IMPORT_ERROR = None


class TBLogger(object):
    """Thin wrapper around a tensorboardX ``SummaryWriter`` for scalar metrics.

    Each instance writes to a fresh run directory ``<log_path>/<name>_<i>``,
    where ``i`` is the smallest non-negative integer whose path does not exist
    yet, so successive runs never share an event directory.
    """

    def __init__(self, log_path="./logging_data", name="run"):
        """Create the log root (if needed) and open a writer on a new run dir.

        Args:
            log_path: Root directory holding all runs; created if missing.
            name: Base name of the run directory; a ``_<i>`` suffix is appended.

        Raises:
            ImportError: If tensorboardX could not be imported.
        """
        super(TBLogger, self).__init__()
        if SummaryWriter is None:
            raise ImportError(
                "TBLogger requires the 'tensorboardX' package"
            ) from _TB_IMPORT_ERROR
        os.makedirs(log_path, exist_ok=True)
        self.last_step = 0
        self.log_path = log_path
        raw_name = os.path.join(log_path, name)
        # Unbounded probe (_0, _1, ...) so an existing run directory is never
        # reused, no matter how many runs have accumulated.
        for i in itertools.count():
            run_dir = f"{raw_name}_{i}"
            if not os.path.exists(run_dir):
                break
        self.writer = SummaryWriter(logdir=run_dir)

    def note(self, metrics, step=None):
        """Log every ``key -> value`` pair in *metrics* as a scalar at *step*.

        Args:
            metrics: Mapping of tag name to scalar value.
            step: Global step to record. When omitted, the step of the previous
                call is reused (0 before any call); it is NOT auto-incremented.
        """
        if step is None:
            step = self.last_step
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)
        self.last_step = step

    def finish(self):
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()


def setup_logging(log_file_path):
    """Configure and return the ``'PretrainLogger'`` logger.

    Messages at INFO and above go both to *log_file_path* (opened in append
    mode, UTF-8) and to the console (stderr), formatted as
    ``time - level - message``.

    Calling this again replaces the handlers attached to ``'PretrainLogger'``
    instead of stacking new ones (stacking would duplicate every log line).

    Args:
        log_file_path: Path of the log file; missing parent directories are
            created.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('PretrainLogger')
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on every call, so drop (and close)
    # handlers left over from an earlier setup before attaching fresh ones.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger