Spaces:
Configuration error
Configuration error
| # A simple torch style logger | |
| # (C) Wei YANG 2017 | |
| from __future__ import absolute_import | |
| import matplotlib.pyplot as plt | |
| import os | |
| import sys | |
| import numpy as np | |
# Public API of this module; plot_overlap is an internal helper.
__all__ = [
    'Logger',
    'LoggerMonitor',
    'savefig',
]
def savefig(fname, dpi=None, **kwargs):
    """Save the current matplotlib figure to *fname*.

    Args:
        fname: Destination path passed straight to ``plt.savefig``.
        dpi: Resolution in dots per inch; 150 is used when None.
        **kwargs: Extra keyword arguments forwarded to ``plt.savefig``
            (e.g. ``bbox_inches='tight'``).
    """
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi, **kwargs)
def plot_overlap(logger, names=None):
    """Plot selected series of *logger* onto the current matplotlib axes.

    Args:
        logger: Object exposing ``names`` (list of series names),
            ``numbers`` (dict mapping name -> list of values) and ``title``.
        names: Series names to plot; defaults to ``logger.names``.

    Returns:
        List of legend labels, one per plotted series, formatted as
        ``title + '(' + name + ')'``.
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        # Values read back from a resumed log file may be strings; coerce to
        # float so they are plotted as numbers rather than as text labels.
        values = np.asarray(numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log file is tab-separated: a header line with the column names, then
    one line per ``append`` call with each value formatted as ``{:.6f}``.
    Every field, including the last one on a line, is followed by a tab.
    '''

    def __init__(self, fpath, title=None, resume=False):
        """Create a logger, optionally resuming from an existing log file.

        Args:
            fpath: Path of the log file. When None, no file is opened and
                values are only kept in memory.
            title: Prefix used for plot legend labels; defaults to ''.
            resume: When True, load the header and rows already in ``fpath``
                into ``self.names`` / ``self.numbers`` and reopen the file in
                append mode. When False, the file is opened with mode 'w'.
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                self._load(fpath)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def _load(self, fpath):
        """Read column names and previously logged rows from ``fpath``."""
        with open(fpath, 'r') as f:
            self.names = f.readline().rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in f:
                line = line.rstrip()
                if not line:
                    # Ignore blank lines instead of logging '' into column 0.
                    continue
                for i, value in enumerate(line.split('\t')):
                    # Rows are stored as text; convert back so resumed series
                    # are numeric rather than strings.
                    self.numbers[self.names[i]].append(float(value))

    def set_names(self, names):
        """Declare the column names and write them as the header line.

        Resets ``self.numbers`` to an empty list per name. Calling this on a
        resumed logger appends a second header line to the file and discards
        the values loaded from it, so it is intended for fresh logs.

        Args:
            names: Ordered list of column names.
        """
        # initialize numbers as empty list
        self.names = names
        self.numbers = {name: [] for name in names}
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in names) + '\n')
            self.file.flush()

    def append(self, numbers):
        """Record one row of values and write it to the log file.

        Args:
            numbers: Sequence of numbers, one per column, in ``self.names``
                order.

        Raises:
            AssertionError: If the number of values differs from the number
                of column names.
        """
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so a bad value leaves neither the file
        # nor the in-memory series half-updated.
        line = ''.join('{0:.6f}\t'.format(num) for num in numbers) + '\n'
        if self.file is not None:
            self.file.write(line)
            self.file.flush()
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        """Plot the selected series on the current axes with legend and grid.

        Args:
            names: Series names to plot; defaults to all ``self.names``.
        """
        names = self.names if names is None else names
        for name in names:
            values = np.asarray(self.numbers[name], dtype=float)
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the log file, if one was opened."""
        if self.file is not None:
            self.file.close()
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log listed in *paths*.

        Args:
            paths: Dictionary of {title: filepath} pairs; each file is read
                with ``Logger(filepath, title=title, resume=True)``.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # The monitor only reads logs; close the append-mode handle that
            # resume=True opens so files are not held open indefinitely.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Overlay the selected series of every loaded log in a new figure.

        Args:
            names: Series names to plot from each log; when None, each
                logger's own ``names`` are used.
        '''
        plt.figure()
        # Draw in the left half of the figure, presumably to leave room for
        # the legend anchored just outside the axes' right edge below.
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2,
                   borderaxespad=0.)
        plt.grid(True)
if __name__ == '__main__':
    # Example: single logger
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor.
    # Logs to compare may be given as title=path arguments, e.g.
    #   python <this_script> net20=/path/a/log.txt net32=/path/b/log.txt
    # With no arguments the original hard-coded CIFAR-10 runs are used.
    paths = {}
    for arg in sys.argv[1:]:
        title, sep, path = arg.partition('=')
        if not sep:
            sys.exit('expected title=path argument, got %r' % arg)
        paths[title] = path
    if not paths:
        paths = {
            'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
            'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
            'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
        }
    field = ['Valid Acc.']
    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')