| | |
| | |
| | from __future__ import absolute_import |
| | import matplotlib.pyplot as plt |
| | import os |
| | import sys |
| | import numpy as np |
| |
|
# Names exported by ``from <this module> import *``.
__all__ = [
    'Logger',
    'LoggerMonitor',
    'savefig',
]
| |
|
def savefig(fname, dpi=None):
    '''Save the current matplotlib figure to ``fname``.

    Args:
        fname: destination passed straight to ``plt.savefig``.
        dpi: resolution in dots per inch; ``None`` selects 150.
    '''
    # Identity test: ``== None`` would invoke a custom ``__eq__`` on dpi.
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)
| | |
def plot_overlap(logger, names=None):
    '''Draw the selected series of ``logger`` onto the current axes.

    Args:
        logger: object exposing ``names``, ``numbers`` (a mapping from
            series name to a sequence of values) and ``title``.
        names: series to draw; ``None`` draws every series in
            ``logger.names``.

    Returns:
        One legend label per drawn series, formatted ``title(name)``.

    Raises:
        KeyError: if a requested name is not in ``logger.numbers``.
        ValueError: if a series holds values that are not numeric.
    '''
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        # Values read back from a log file may still be text; coerce them so
        # matplotlib draws a numeric curve instead of a categorical axis.
        values = np.asarray(numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + '(' + name + ')' for name in names]
| |
|
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: the first line holds the column
    names and every following line holds one row of values written with six
    decimals. Every field, including the last one, is followed by a tab.

    Attributes:
        file: the open log file, or ``None`` when no path was given.
        resume: the ``resume`` flag passed to the constructor.
        title: prefix used for legend labels (empty string by default).
        names: column names, in file order.
        numbers: dict mapping each column name to its list of values.
    '''

    def __init__(self, fpath, title=None, resume=False):
        '''Open (or re-open) a log file.

        Args:
            fpath: path of the log file. ``None`` keeps the log in memory
                only; nothing is written to disk.
            title: label prefix used by ``plot``; ``None`` means ''.
            resume: when true, load the existing header and rows from
                ``fpath`` and continue appending to it; otherwise the file
                is truncated and started from scratch.

        Raises:
            IOError/OSError: if the file cannot be opened.
            ValueError: if a previously logged field is not a number.
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so the other methods never hit a missing attribute,
        # whichever branch below is taken.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load(fpath)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    def _load(self, fpath):
        '''Read the header and all previously logged rows from ``fpath``.'''
        # ``with`` guarantees the read handle is released even when a row
        # fails to parse.
        with open(fpath, 'r') as log_file:
            header = log_file.readline()
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in log_file:
                row = line.rstrip()
                if not row:
                    # A blank line carries no values; skip it.
                    continue
                # Store floats, matching what append() keeps in memory, so
                # that resumed and freshly logged values plot the same way.
                for name, field in zip(self.names, row.split('\t')):
                    self.numbers[name].append(float(field))

    def set_names(self, names):
        '''Define the columns, reset the stored values and write the header.

        The header is written regardless of the ``resume`` flag, so calling
        this on a resumed log appends a second header line to the file and
        discards the values loaded into memory.

        Args:
            names: sequence of column names.
        '''
        self.names = names
        self.numbers = {name: [] for name in self.names}
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in self.names))
            self.file.write('\n')
            self.file.flush()

    def append(self, numbers):
        '''Record one row of values, one per column and in column order.

        Args:
            numbers: sequence of numbers with the same length as ``names``.

        Raises:
            AssertionError: if the number of values differs from the number
                of columns.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so a value that cannot be formatted
        # leaves both the file and the in-memory history untouched.
        row = ''.join("{0:.6f}".format(num) + '\t' for num in numbers)
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)
        if self.file is not None:
            self.file.write(row)
            self.file.write('\n')
            self.file.flush()

    def plot(self, names=None):
        '''Plot the selected columns on the current axes with a legend.

        Args:
            names: columns to draw; ``None`` draws all of them.
        '''
        names = self.names if names is None else names
        for name in names:
            values = np.asarray(self.numbers[name])
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the log file if one is open.'''
        if self.file is not None:
            self.file.close()
| |
|
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log listed in ``paths``.

        Args:
            paths: dictionary of {name: filepath} pairs; each name becomes
                the title of the corresponding logger.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Resuming leaves the log open for appending, but the monitor
            # only reads the loaded values, so release the handle right away.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Overlay the selected series of all loaded logs in a new figure.

        Args:
            names: series to draw from each log; ``None`` draws every
                series of every log.
        '''
        plt.figure()
        # Draw in the left half of the figure; the legend is anchored just
        # outside the right edge of these axes.
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
| | |
if __name__ == '__main__':
    # Example: overlay one metric from several training logs and save the
    # resulting figure.
    log_files = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }
    metrics = ['Valid Acc.']

    LoggerMonitor(log_files).plot(names=metrics)
    savefig('test.eps')