xuan3986's picture
Upload 111 files
03022ee verified
import logging
# Module-level logger used by EpochLogger.log_stats and EpochCounter.__next__.
logger: logging.Logger = logging.getLogger(name=__name__)
class EpochLogger(object):
    """Append one human-readable summary line per call to a text log file.

    Each line has the form ``"<meta stats> - <stage> <stats>"``; every stat is
    rendered as ``key: value`` and floats are formatted by
    :meth:`item_to_string`.
    """

    def __init__(self, save_file, precision=2):
        """
        Args:
            save_file: Path of the log file; it is opened in append mode on
                every :meth:`log_stats` call.
            precision: Digits after the decimal point used for float values.
        """
        self.save_file = save_file
        self.precision = precision

    def item_to_string(self, key, value, prefix=None):
        """Render a single ``key: value`` pair.

        Floats strictly between 1.0 and 100.0 use fixed-point notation; every
        other float (including exactly 1.0 and 100.0) uses scientific
        notation. Non-float values are interpolated unchanged.

        Args:
            key: Stat name.
            value: Stat value.
            prefix: Optional string prepended to ``key``, separated by a space.

        Returns:
            The formatted ``"key: value"`` string.
        """
        if isinstance(value, float) and 1.0 < value < 100.0:
            value = f"{value:.{self.precision}f}"
        elif isinstance(value, float):
            # Very small / large magnitudes are unreadable in fixed point.
            value = f"{value:.{self.precision}e}"
        if prefix is not None:
            key = f"{prefix} {key}"
        return f"{key}: {value}"

    def stats_to_string(self, stats, prefix=None):
        """Render every item of the mapping ``stats``, joined by ``", "``.

        Args:
            stats: Mapping of stat name to value.
            prefix: Optional prefix forwarded to :meth:`item_to_string`.
        """
        return ", ".join(
            [self.item_to_string(k, v, prefix) for k, v in stats.items()]
        )

    def log_stats(
        self,
        stats_meta,
        stats=None,
        stage='train',
        verbose=True,
    ):
        """Append one summary line to ``save_file`` and optionally log it.

        Args:
            stats_meta: Mapping rendered first, without a prefix.
            stats: Optional mapping rendered after ``" - "`` with every key
                prefixed by ``stage``. ``None`` (the default) omits that part.
            stage: Prefix applied to the keys of ``stats``.
            verbose: If true, also emit the line through the module logger
                at INFO level.
        """
        string = self.stats_to_string(stats_meta)
        if stats is not None:
            string += " - " + self.stats_to_string(stats, stage)
        # Explicit UTF-8: non-ASCII keys/values must not depend on the
        # platform's default (locale) encoding.
        with open(self.save_file, "a", encoding="utf-8") as fw:
            print(string, file=fw)
        if verbose:
            logger.info(string)
class EpochCounter:
    """Iterator over epoch numbers ``1..limit`` whose position can be saved.

    ``current`` holds the most recently yielded epoch (0 before the first).
    :meth:`save` writes it as plain text and :meth:`load` reads it back, so
    iteration continues from the stored epoch.
    """

    def __init__(self, limit):
        """
        Args:
            limit: Last epoch number that iteration will yield.
        """
        self.limit = limit
        self.current = 0

    def __iter__(self):
        """Return the counter itself; it is its own iterator."""
        return self

    def __next__(self):
        """Advance ``current`` by one and return it; stop once ``limit`` is reached."""
        if not self.current < self.limit:
            raise StopIteration
        self.current += 1
        logger.info(f"Going into epoch {self.current}")
        return self.current

    def save(self, path, device=None):
        """Overwrite ``path`` with the decimal text of ``current``.

        ``device`` is unused here; presumably it is kept so the signature
        matches other checkpointable objects — verify against callers.
        """
        text = str(self.current)
        with open(path, "w") as handle:
            handle.write(text)

    def load(self, path, device=None):
        """Set ``current`` to the integer stored in ``path``.

        Raises ``ValueError`` if the file content is not an integer.
        ``device`` is unused here (see :meth:`save`).
        """
        with open(path) as handle:
            self.current = int(handle.read())