Spaces:
Build error
Build error
| r""" Logging """ | |
| import datetime | |
| import logging | |
| import os | |
| from tensorboardX import SummaryWriter | |
| import torch | |
class Logger:
    r""" Writes results of training/testing

    Every method is a classmethod and all state is stored on the class itself
    (``logpath``, ``benchmark``, ``tbd_writer``), so callers use the class
    directly, e.g. ``Logger.info(msg)``. ``initialize`` must run before
    ``save_model`` (which needs ``logpath``).
    """

    @classmethod
    def initialize(cls, args, training):
        r""" Create the log directory and configure file, console and TensorBoard logging.

        Args:
            args: argument namespace. Reads ``args.logpath`` (training),
                ``args.load`` (testing) and ``args.benchmark``; when training,
                every entry of ``args.__dict__`` is written to the log.
            training: if True the directory is named after ``args.logpath``
                (or just the timestamp when that is empty); otherwise it is
                ``'_TEST_'`` + checkpoint file stem + timestamp.

        Raises:
            FileExistsError: if ``logs/<logpath>.log`` already exists
                (``os.makedirs`` is called without ``exist_ok``).
        """
        logtime = datetime.datetime.now().strftime('_%m%d_%H%M%S')
        if training:
            logpath = args.logpath
        else:
            # Checkpoint stem: text after the last '/' and before the first '.'.
            logpath = '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
        if logpath == '':
            logpath = logtime

        # NOTE: despite the '.log' suffix this is a directory; log.txt, the
        # TensorBoard runs and saved checkpoints are all written inside it.
        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config: also echo root-logger INFO messages to the console.
        # NOTE(review): calling initialize() twice would attach a second console
        # handler and duplicate console output.
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments (training runs only)
        if training:
            logging.info(':======== Convolutional Hough Matching Networks =========')
            for arg_key in args.__dict__:
                logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
            logging.info(':========================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes message to .txt (and the console) through the root logger.

        Args:
            msg: text to log at INFO level.
        """
        logging.info(msg)

    @classmethod
    def save_model(cls, model, epoch, val_pck):
        r""" Save ``model.state_dict()`` as ``<logpath>/pck_best_model.pt`` and log it.

        The file is overwritten on every call, so the checkpoint on disk always
        comes from the most recent call.

        Args:
            model: module whose ``state_dict()`` is saved.
            epoch: epoch number, shown in the log message.
            val_pck: validation PCK, shown in the log message.
        """
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
        cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))
class AverageMeter:
    r""" Stores loss, evaluation results, selected layers

    Accumulates per-batch evaluation values (currently only ``'pck'``) and
    losses, and writes running averages through ``Logger.info``.
    """

    def __init__(self, benchamrk):
        r""" Constructor of AverageMeter

        Args:
            benchamrk: benchmark name. It is not used by this class; the
                parameter name (typo included) is kept so existing keyword
                callers keep working.
        """
        self.buffer_keys = ['pck']
        self.buffer = {key: [] for key in self.buffer_keys}
        self.loss_buffer = []

    @staticmethod
    def _mean(values):
        r""" Arithmetic mean of a non-empty sequence (numbers or tensors). """
        return sum(values) / len(values)

    def update(self, eval_result, loss=None):
        r""" Append one batch of results.

        Args:
            eval_result: mapping holding an iterable for every key in
                ``buffer_keys``; its items are appended to the matching buffer.
            loss: optional loss value for this batch; skipped when None.
        """
        for key in self.buffer_keys:
            # list += iterable extends the buffer item by item
            self.buffer[key] += eval_result[key]
        if loss is not None:
            self.loss_buffer.append(loss)

    def write_result(self, split, epoch):
        r""" Log the epoch summary: mean loss (if any) and the mean of each buffer.

        Empty buffers are skipped instead of raising ZeroDivisionError.
        NOTE(review): unlike write_process, the buffer mean is not scaled by
        100 here; confirm which scale the stored values use.
        """
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        if self.loss_buffer:
            msg += 'Loss: %5.2f ' % self._mean(self.loss_buffer)
        for key in self.buffer_keys:
            if self.buffer[key]:
                msg += '%s: %6.2f ' % (key.upper(), self._mean(self.buffer[key]))
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch):
        r""" Log training progress: last/mean loss and mean of each buffer x100.

        Args:
            batch_idx: zero-based batch index (printed one-based).
            datalen: total number of batches.
            epoch: current epoch.
        """
        msg = '[Epoch: %02d] ' % epoch
        msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        if self.loss_buffer:
            msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
            msg += 'Avg Loss: %5.5f ' % self._mean(self.loss_buffer)
        for key in self.buffer_keys:
            if self.buffer[key]:  # avoid ZeroDivisionError before any update
                msg += 'Avg %s: %5.2f ' % (key.upper(), self._mean(self.buffer[key]) * 100)
        Logger.info(msg)

    def write_test_process(self, batch_idx, datalen):
        r""" Log test progress; for 'pck' prints every entry of the stacked mean x100.

        Assumes each stored 'pck' item is a 1-D tensor (e.g. one value per
        threshold) -- TODO confirm against the evaluator.
        """
        msg = '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        for key in self.buffer_keys:
            values = self.buffer[key]
            if not values:  # torch.stack([]) / division by zero would raise
                continue
            if key == 'pck':
                pcks = torch.stack(values).mean(dim=0) * 100
                val = ''.join('%5.2f ' % p.item() for p in pcks)
                msg += 'Avg %s: %s ' % (key.upper(), val)
            else:
                msg += 'Avg %s: %5.2f ' % (key.upper(), self._mean(values))
        Logger.info(msg)

    def get_test_result(self):
        r""" Return ``{key: stacked mean over dim 0 * 100}`` for every buffer.

        Raises:
            RuntimeError: from ``torch.stack`` if a buffer is still empty.
        """
        return {key: torch.stack(self.buffer[key]).mean(dim=0) * 100
                for key in self.buffer_keys}