import os import sys import random import numpy as np from collections import OrderedDict from tabulate import tabulate from pandas import DataFrame from time import gmtime, strftime class Logger: def __init__(self, env_info, fmt=None): self.handler = True self.scalar_metrics = OrderedDict() self.fmt = fmt if fmt else dict() base = './logs' if not os.path.exists(base): os.mkdir(base) self.path = '%s/%s-%s' % (base, env_info['name'], env_info['seed']) self.logs = self.path + '.csv' self.output = self.path + '.out' self.checkpoint = self.path + '.cpt' def prin(*args): str_to_write = ' '.join(map(str, args)) with open(self.output, 'a') as f: f.write(str_to_write + '\n') f.flush() print(str_to_write) sys.stdout.flush() self.print = prin def add_scalar(self, t, key, value): if key not in self.scalar_metrics: self.scalar_metrics[key] = [] self.scalar_metrics[key] += [(t, value)] def add_dict(self, t, d): for key, value in d.iteritems(): self.add_scalar(t, key, value) def add(self, t, **args): for key, value in args.items(): self.add_scalar(t, key, value) def iter_info(self, order=None): names = list(self.scalar_metrics.keys()) if order: names = order values = [self.scalar_metrics[name][-1][1] for name in names] t = int(np.max([self.scalar_metrics[name][-1][0] for name in names])) fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.1f' for name in names] if self.handler: self.handler = False self.print(tabulate([[t] + values], ['epoch'] + names, floatfmt=fmt)) else: self.print(tabulate([[t] + values], ['epoch'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1]) def save(self, silent=False): result = None for key in self.scalar_metrics.keys(): if result is None: result = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t') else: df = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t') result = result.join(df, how='outer') result.to_csv(self.logs) if not silent: self.print('The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt')