| import logging | |
| import os | |
| from tensorboardX import SummaryWriter | |
class TBLogger(object):
    """Thin wrapper around a tensorboardX ``SummaryWriter`` for scalar metrics.

    Each instance writes to a new run directory ``<log_path>/<name>_<i>``,
    where ``i`` is the smallest non-negative integer whose directory does not
    exist yet, so sequential runs do not share an event directory.

    Can be used as a context manager; ``finish()`` is called on exit.
    """

    def __init__(self, log_path="./logging_data", name="run"):
        """Create the logging root if needed and open a writer on a new run dir.

        Args:
            log_path: Root directory under which run directories are created.
            name: Prefix of the run directory; a ``_<i>`` suffix is appended.
        """
        super().__init__()
        # exist_ok=True makes a separate existence check unnecessary.
        os.makedirs(log_path, exist_ok=True)
        # Step reused by note() when the caller does not pass one explicitly.
        self.last_step = 0
        self.log_path = log_path
        run_dir = self._next_free_run_dir(os.path.join(log_path, name))
        self.writer = SummaryWriter(logdir=run_dir)

    @staticmethod
    def _next_free_run_dir(raw_name):
        """Return the first ``f"{raw_name}_{i}"`` (i = 0, 1, 2, ...) that does not exist.

        The search is unbounded, so it never falls back to reusing an existing
        run directory, however many earlier runs are present.
        """
        index = 0
        while os.path.exists(f"{raw_name}_{index}"):
            index += 1
        return f"{raw_name}_{index}"

    def note(self, metrics, step=None):
        """Log every scalar in ``metrics`` at a single step.

        Args:
            metrics: Mapping of tag -> value; each pair is passed to
                ``SummaryWriter.add_scalar``.
            step: Global step for all values. When ``None``, the step used by
                the previous call (initially 0) is reused; it is NOT
                auto-incremented.
        """
        if step is None:
            step = self.last_step
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)
        self.last_step = step

    def finish(self):
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.finish()
        # Do not suppress exceptions raised inside the with-block.
        return False
def setup_logging(log_file_path):
    """Configure and return the ``'PretrainLogger'`` logger.

    Messages at INFO and above are written both to ``log_file_path``
    (appended, UTF-8) and to the console, using a
    ``"time - level - message"`` format.

    Calling this again reconfigures the logger instead of stacking
    duplicate handlers. Any handlers already attached to ``'PretrainLogger'``
    are removed and closed first.

    Args:
        log_file_path: Path of the log file. Missing parent directories
            are created.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('PretrainLogger')
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on every call, so earlier handlers
    # would otherwise accumulate and duplicate every message.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    parent_dir = os.path.dirname(log_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger