Spaces:
Running on Zero
Running on Zero
File size: 1,750 Bytes
03022ee | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import logging
# Module-level logger used by the classes below: EpochLogger.log_stats sends
# each line here when verbose=True, and EpochCounter.__next__ announces every
# new epoch through it at INFO level.
logger = logging.getLogger(__name__)
class EpochLogger(object):
    """Format per-epoch statistics as one text line and append it to a file.

    Each call to :meth:`log_stats` appends a single line such as
    ``epoch: 3 - train loss: 2.50, train lr: 1.00e-03`` to ``save_file`` and,
    when ``verbose`` is True, also sends that line to ``logger.info``.

    Arguments
    ---------
    save_file : str or path-like
        File the lines are appended to. It is reopened in append mode on
        every call, so it is created on first use.
    precision : int
        Number of digits after the decimal point used when formatting floats.
    """

    def __init__(self, save_file, precision=2):
        self.save_file = save_file
        self.precision = precision

    def item_to_string(self, key, value, prefix=None):
        """Render one ``key: value`` pair as a string.

        Floats strictly between 1.0 and 100.0 use fixed-point notation with
        ``self.precision`` decimals. Every other float (including exactly 1.0
        or 100.0, zero and negatives) uses scientific notation with the same
        precision. Non-float values are inserted into the f-string unchanged.

        Arguments
        ---------
        key : str
            Name of the statistic.
        value : object
            Value of the statistic; only ``float`` instances are reformatted.
        prefix : str, optional
            If given, prepended to ``key`` separated by a single space.

        Returns
        -------
        str
            ``"<prefix> <key>: <value>"`` or ``"<key>: <value>"``.
        """
        if isinstance(value, float) and 1.0 < value < 100.0:
            value = f"{value:.{self.precision}f}"
        elif isinstance(value, float):
            value = f"{value:.{self.precision}e}"
        if prefix is not None:
            key = f"{prefix} {key}"
        return f"{key}: {value}"

    def stats_to_string(self, stats, prefix=None):
        """Join every item of the mapping ``stats`` with ``", "``.

        Items are rendered with :meth:`item_to_string` in the mapping's
        iteration order, each using the same ``prefix``.
        """
        return ", ".join(
            [self.item_to_string(k, v, prefix) for k, v in stats.items()]
        )

    def log_stats(
        self,
        stats_meta,
        stats,
        stage='train',
        verbose=True,
    ):
        """Append one line built from ``stats_meta`` and ``stats`` to the file.

        Arguments
        ---------
        stats_meta : mapping
            Items written first, without any key prefix (presumably epoch
            bookkeeping such as the epoch number — confirm with callers).
        stats : mapping or None
            Items written after ``" - "``, each key prefixed with ``stage``.
            When None, only ``stats_meta`` is written.
        stage : str
            Prefix applied to every key of ``stats``.
        verbose : bool
            If True, the same line is also emitted via ``logger.info``.
        """
        string = self.stats_to_string(stats_meta)
        if stats is not None:
            string += " - " + self.stats_to_string(stats, stage)
        # Explicit UTF-8: the default text encoding is locale-dependent, so
        # non-ASCII keys/values could raise UnicodeEncodeError (or be written
        # in a platform-specific encoding) without it.
        with open(self.save_file, "a", encoding="utf-8") as fw:
            print(string, file=fw)
        if verbose:
            logger.info(string)
class EpochCounter(object):
    """Stateful iterator that hands out epoch numbers up to ``limit``.

    ``current`` records the last epoch returned (0 before the first one).
    Each step increments it and returns the new value; iteration ends once
    ``current`` is no longer less than ``limit``. Because the object is its
    own iterator, persisting ``current`` with :meth:`save` and restoring it
    with :meth:`load` lets a later loop continue after the saved epoch.

    Arguments
    ---------
    limit : number
        Upper bound compared against ``current`` with ``<``; the last epoch
        yielded is the first value of ``current`` that reaches it.
    """

    def __init__(self, limit):
        self.limit = limit
        self.current = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Advance to and return the next epoch, or stop at ``limit``."""
        if not self.current < self.limit:
            raise StopIteration
        self.current = self.current + 1
        logger.info(f"Going into epoch {self.current}")
        return self.current

    def save(self, path, device=None):
        """Write ``current`` as decimal text to ``path``, overwriting it.

        ``device`` is accepted but not used (presumably kept so the method
        matches a checkpointing hook signature — confirm).
        """
        contents = str(self.current)
        with open(path, "w") as handle:
            handle.write(contents)

    def load(self, path, device=None):
        """Restore ``current`` from the integer text stored in ``path``.

        ``device`` is accepted but not used. Raises ``ValueError`` if the
        file does not contain an integer.
        """
        with open(path) as handle:
            text = handle.read()
        self.current = int(text)
|