Spaces:
Sleeping
Sleeping
File size: 2,435 Bytes
37163a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import shutil
import time
import logging
from torch.optim import SGD, Adam, AdamW
def ensure_path(path, replace=True, force_replace=False):
    """Make sure the directory *path* exists, optionally emptying it first.

    A missing *path* is created together with any missing parents. An
    existing *path* is left alone unless *replace* is true, in which case it
    is deleted and recreated empty when any of the following holds:

    * its final component starts with ``'_'`` (treated as a scratch dir),
    * *force_replace* is true,
    * the user answers ``y`` to an interactive prompt on stdin.

    Args:
        path: directory path as a ``str``; a trailing ``'/'`` is ignored
            when checking for the scratch-dir prefix.
        replace: whether an existing directory may be wiped at all.
        force_replace: wipe an existing directory without prompting.
    """
    scratch = os.path.basename(path.rstrip('/')).startswith('_')
    if not os.path.exists(path):
        os.makedirs(path)
        return
    if not replace:
        return
    confirmed = (
        scratch
        or force_replace
        or input(f'{path} exists, replace? y/[n] ') == 'y'
    )
    if confirmed:
        shutil.rmtree(path)
        # The parent is known to exist at this point, so a single-level
        # mkdir is enough.
        os.mkdir(path)
def set_logger(file_path):
    """Configure the root logger to emit INFO records to stderr and a file.

    Both handlers use the format ``[MM-DD HH:MM:SS] message``; the file is
    opened in append mode. The function is idempotent per destination:
    handlers it installs are tagged with a name, and a handler that is
    already attached (the console one, or a file one for the same absolute
    path) is not added again. Repeated calls in the same process therefore
    do not duplicate log lines or leak open file handles.

    Args:
        file_path: path of the log file to append to.

    Returns:
        The root logger.
    """
    logger = logging.getLogger()
    logger.setLevel('INFO')
    formatter = logging.Formatter('[%(asctime)s] %(message)s', '%m-%d %H:%M:%S')

    attached = {h.get_name() for h in logger.handlers}
    stream_name = 'set_logger:stream'
    file_name = f'set_logger:file:{os.path.abspath(file_path)}'

    # Build handlers lazily: constructing a FileHandler opens the file, so
    # only create one when it will actually be attached.
    new_handlers = []
    if stream_name not in attached:
        new_handlers.append((stream_name, logging.StreamHandler()))
    if file_name not in attached:
        new_handlers.append((file_name, logging.FileHandler(file_path, 'a')))

    for name, handler in new_handlers:
        handler.set_name(name)
        handler.setFormatter(formatter)
        handler.setLevel('INFO')
        logger.addHandler(handler)
    return logger
def compute_num_params(model, text=True):
    """Count the elements across all of ``model.parameters()``.

    Args:
        model: object exposing ``parameters()``, whose items expose
            ``numel()``.
        text: when true, return a human-readable string instead of the raw
            count.

    Returns:
        The integer total if *text* is false; otherwise a string such as
        ``'1.5M (1500000)'``, scaled to millions when the total is at least
        one million and to thousands otherwise.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    if not text:
        return total
    scale, suffix = (1e6, 'M') if total >= 1e6 else (1e3, 'K')
    return f'{total / scale:.1f}{suffix} ({total})'
def make_optimizer(params, optimizer_spec):
    """Instantiate a ``torch.optim`` optimizer from a config mapping.

    Args:
        params: parameters (or parameter-group dicts) passed straight to the
            optimizer constructor.
        optimizer_spec: mapping with key ``'name'`` (one of ``'sgd'``,
            ``'adam'``, ``'adamw'``) and optional key ``'args'``, a mapping
            of constructor keyword arguments. A missing or null ``'args'``
            means "use the optimizer's defaults".

    Returns:
        The constructed optimizer.

    Raises:
        KeyError: if ``'name'`` is missing or names an unsupported optimizer.
    """
    optimizers = {
        'sgd': SGD,
        'adam': Adam,
        'adamw': AdamW,
    }
    name = optimizer_spec['name']
    try:
        optimizer_cls = optimizers[name]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is valid.
        raise KeyError(
            f'unknown optimizer {name!r}; expected one of {sorted(optimizers)}'
        ) from None
    # Config files often leave 'args' out or empty (null); treat both as {}.
    args = optimizer_spec.get('args') or {}
    return optimizer_cls(params, **args)
class Averager:
    """Running weighted mean of values.

    ``Averager()`` starts empty (mean 0, weight 0); ``Averager(v)`` starts
    from the single value *v* with weight 1.
    """

    def __init__(self, v=None):
        if v is None:
            self.n = 0.  # accumulated weight
            self.v = 0.  # current weighted mean
        else:
            self.n = 1.
            self.v = v

    def add(self, v, n=1.0):
        """Fold value *v* with weight *n* into the running mean.

        A zero weight contributes nothing and is ignored. Without this
        guard, the update divides by zero when the averager is empty, and a
        NaN/inf *v* would poison the mean through ``v * 0``.
        """
        if n == 0:
            return
        total = self.n + n
        self.v = self.v * (self.n / total) + v * (n / total)
        self.n = total

    def item(self):
        """Return the current weighted mean (0. if nothing has been added)."""
        return self.v
class EpochTimer:
    """Wall-clock timer for a run of *max_epoch* epochs.

    Timestamps come from ``time.time()``; ``t_start`` is fixed at
    construction and ``t_last`` is moved forward on every ``epoch_done``.
    """

    def __init__(self, max_epoch):
        self.max_epoch = max_epoch
        self.epoch = 0
        self.t_start = self.t_last = time.time()

    def epoch_done(self):
        """Record the end of an epoch and return formatted timings.

        Returns:
            A 3-tuple of ``time_text`` strings: time since the previous call
            (or construction), total time since construction, and the total
            run time linearly extrapolated to ``max_epoch`` epochs from the
            average epoch time so far.
        """
        now = time.time()
        self.epoch += 1
        since_last = now - self.t_last
        since_start = now - self.t_start
        projected = since_start / self.epoch * self.max_epoch
        self.t_last = now
        return tuple(time_text(span) for span in (since_last, since_start, projected))
def time_text(sec):
    """Format a duration in seconds as a short string with one decimal.

    Hours are used from 3600 s upward, minutes from 60 s upward, and
    seconds otherwise, e.g. ``5400 -> '1.5h'``, ``90 -> '1.5m'``,
    ``5 -> '5.0s'``.
    """
    for threshold, unit in ((3600, 'h'), (60, 'm')):
        if sec >= threshold:
            return f'{sec / threshold:.1f}{unit}'
    return f'{sec:.1f}s'
|