File size: 3,209 Bytes
4c62147 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import logging
import sys
# Module-wide logging setup: message-only output, INFO threshold on the root
# logger, and one named logger shared by the helper functions in this module.
format_str = '%(message)s'
logging.basicConfig(format=format_str, level=logging.INFO)
formatter = logging.Formatter(format_str)  # reused for handlers added by log_file()
logger = logging.getLogger('TorchTask')
# ---------------------------------------------------------------------
# Functions for logging
# ---------------------------------------------------------------------
def log_mode(debug=False):
    """Set the threshold of the module-level ``logger``.

    Args:
        debug: when truthy the logger lets DEBUG records through; otherwise
            only INFO and above.
    """
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
def log_file(fpath, debug=False):
    """Additionally send log records to the file at ``fpath``.

    A new ``logging.FileHandler`` using the module-level ``formatter`` is
    attached to the module-level ``logger``. Every call adds another handler,
    so calling this twice for the same path writes each record twice.

    Args:
        fpath: path of the log file (FileHandler opens it in append mode by default).
        debug: when truthy the handler accepts DEBUG records; otherwise only
            INFO and above. Records must also pass the logger's own level.
    """
    handler = logging.FileHandler(fpath)
    handler.setLevel(logging.DEBUG if debug else logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
def log_info(message):
    """Log ``message`` at INFO level on the module-level ``logger``.

    Args:
        message: a string, or a list of strings that is concatenated with no
            separator before logging.
    """
    text = ''.join(message) if isinstance(message, list) else message
    logger.info(text)
def log_warn(message):
    """Log ``message`` at WARNING level, framed by a 78-column '=' banner.

    Args:
        message: a string, or a list of strings that is concatenated with no
            separator. The closing rule is appended directly after the text,
            so the message should end with a newline to keep the rule on its
            own line.
    """
    text = ''.join(message) if isinstance(message, list) else message
    # Header is 36 + len(' WARN ') + 36 == 78 columns, matching the footer.
    banner = '\n' + '=' * 36 + ' WARN ' + '=' * 36 + '\n' + text + '=' * 78 + '\n'
    # Logger.warn is a deprecated alias; Logger.warning is the supported API.
    logger.warning(banner)
def log_err(message):
    """Log ``message`` at ERROR level, framed by a 78-column '=' banner, then exit.

    Args:
        message: a string, or a list of strings that is concatenated with no
            separator. The closing rule is appended directly after the text,
            so the message should end with a newline to keep the rule on its
            own line.

    Raises:
        SystemExit: always, with exit status 1.
    """
    text = ''.join(message) if isinstance(message, list) else message
    # Header is 35 + len(' ERROR ') + 36 == 78 columns, matching the footer.
    banner = '\n' + '=' * 35 + ' ERROR ' + '=' * 36 + '\n' + text + '=' * 78 + '\n'
    logger.error(banner)
    # The bare exit() builtin is added by the `site` module for interactive
    # use (absent under `python -S`) and exits with status 0, which would
    # report success to the shell. Terminate with a failure status instead.
    sys.exit(1)
class AvgMeter:
    """Track the latest value and the running weighted average of a metric.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` over all updates since the last reset.
        count: running total of the weights ``n`` since the last reset.
        avg: ``sum / count``; left unchanged (0 after a reset) while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (the number of samples it stands for).

        Args:
            val: the newly observed value.
            n: weight of ``val`` in the running average.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been
        # recorded since the last reset.
        if self.count:
            self.avg = self.sum / self.count

    def __format__(self, format_spec):
        """Render as ``"<val> (<avg>)"``, applying ``format_spec`` to both numbers."""
        return "{self.val:{spec}} ({self.avg:{spec}})".format(
            self=self, spec=format_spec)
class AvgMeterSet:
    """A collection of ``AvgMeter`` objects keyed by metric name.

    Meters are created on the first ``update`` for a name. Supports
    ``meter_set[name]``, ``name in meter_set`` and iteration over names.
    """

    def __init__(self):
        # Maps metric name -> AvgMeter.
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def __contains__(self, key):
        # Without this, `key in meter_set` falls back to the legacy
        # __getitem__ sequence protocol (probing meters[0], meters[1], ...)
        # and raises KeyError instead of answering the membership test.
        return key in self.meters

    def __iter__(self):
        # Iterate over metric names, like a dict.
        return iter(self.meters)

    def keys(self):
        """Return a view of the metric names."""
        return self.meters.keys()

    def has_key(self, key):
        """Return True if a meter named ``key`` exists."""
        return key in self.meters

    def update(self, name, value, n=1):
        """Record ``value`` with weight ``n`` on meter ``name``, creating it if needed."""
        if name not in self.meters:
            self.meters[name] = AvgMeter()
        self.meters[name].update(value, n)

    def reset(self, name=None):
        """Reset every meter, or only meter ``name`` when given.

        An unknown ``name`` is reported through ``log_err``.
        """
        if name is None:
            for meter in self.meters.values():
                meter.reset()
        elif name in self.meters:
            self.meters[name].reset()
        else:
            log_err('Unknown key value for AvgMeterSet: {0}\n'.format(name))

    def values(self, postfix=''):
        """Return ``{name + postfix: latest value}`` for every meter."""
        return {name + postfix: meter.val for name, meter in self.meters.items()}

    def averages(self, postfix='/avg'):
        """Return ``{name + postfix: running average}`` for every meter."""
        return {name + postfix: meter.avg for name, meter in self.meters.items()}

    def sums(self, postfix='/sum'):
        """Return ``{name + postfix: weighted sum}`` for every meter."""
        return {name + postfix: meter.sum for name, meter in self.meters.items()}

    def counts(self, postfix='/count'):
        """Return ``{name + postfix: total weight}`` for every meter."""
        return {name + postfix: meter.count for name, meter in self.meters.items()}
|