Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| import random | |
| import logging | |
| from typing import OrderedDict | |
| import torch | |
| import torch.linalg | |
| import numpy as np | |
| import yaml | |
| from easydict import EasyDict | |
| from glob import glob | |
class BlackHole(object):
    """Do-nothing object that silently absorbs any interaction.

    Reading any missing attribute or calling the instance returns the very same
    instance, and attribute assignments are discarded. As a result, chains such
    as ``bh.info('x').flush()`` are harmless no-ops, which makes this usable as a
    placeholder wherever output should be suppressed.
    """

    def __getattr__(self, name):
        # Only reached for names not found by normal lookup; all resolve to self.
        return self

    def __call__(self, *args, **kwargs):
        return self

    def __setattr__(self, name, value):
        # Intentionally drop the assignment so the instance never gains state.
        return None
class Counter(object):
    """Minimal mutable counter.

    ``now`` holds the current value. ``step`` advances it by ``delta`` and hands
    back the value it held before the increment (post-increment semantics).
    """

    def __init__(self, start=0):
        super().__init__()
        self.now = start

    def step(self, delta=1):
        """Advance ``now`` by ``delta`` and return the value it had beforehand."""
        before = self.now
        # Keep the augmented assignment so in-place ``+=`` semantics are preserved.
        self.now += delta
        return before
def get_logger(name, log_dir=None):
    """Return a DEBUG-level logger that writes to stderr and optionally to a file.

    Repeated calls with the same ``name`` reuse handlers that an earlier call
    already attached: at most one plain console handler is added, and at most
    one file handler per log file. Without this, every call would stack another
    handler and each message would be printed multiple times.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_dir: If not None, messages are also written to ``<log_dir>/log.txt``.
            The directory must already exist.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # FileHandler subclasses StreamHandler, so identify the console handler by exact type.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, 'log.txt')
        # FileHandler stores the absolute path in ``baseFilename``; compare against that.
        abs_log_path = os.path.abspath(log_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == abs_log_path
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
    return logger
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create a new timestamped directory under ``root`` and return its path.

    The directory name is ``[prefix_]YYYY_MM_DD__HH_MM_SS[_tag]`` in local time,
    with empty ``prefix``/``tag`` left out. ``os.makedirs`` is called without
    ``exist_ok``, so ``FileExistsError`` is raised if the directory already
    exists (e.g. two identical calls within the same second).
    """
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    name_parts = [piece for piece in (prefix, stamp, tag) if piece != '']
    log_dir = os.path.join(root, '_'.join(name_parts))
    os.makedirs(log_dir)
    return log_dir
def seed_all(seed):
    """Seed every RNG this module uses with ``seed``.

    Seeds PyTorch (CPU, then all CUDA devices), NumPy's global RNG and Python's
    ``random`` module. It also sets ``torch.backends.cudnn.deterministic`` to
    request deterministic cuDNN algorithms.
    """
    torch.backends.cudnn.deterministic = True
    # Same order as the individual calls: torch CPU, torch CUDA, NumPy, stdlib.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
def inf_iterator(iterable):
    """Yield items from ``iterable`` forever, restarting it whenever it runs out.

    Each pass calls ``iter(iterable)`` afresh, so re-iterable objects such as
    lists, or objects whose ``__iter__`` returns a new iterator, cycle
    indefinitely. Unlike ``itertools.cycle``, nothing is cached and every pass
    re-reads the source.

    Raises:
        ValueError: if a full pass produces no items. This happens when the
            iterable is empty, or when it is a one-shot iterator that is already
            exhausted. Without this check the generator would spin forever
            without ever yielding.
    """
    while True:
        yielded_any = False
        for item in iterable:
            yielded_any = True
            yield item
        if not yielded_any:
            raise ValueError('inf_iterator: iterable yielded no items, cannot cycle it')
def log_hyperparams(writer, args):
    """Write the attributes of ``args`` to ``writer`` as TensorBoard hyperparameters.

    String values are stored as-is and every other value is stored as its
    ``repr`` string. An empty metric dict is passed, so no metrics are recorded.
    The three summaries returned by ``hparams`` are added to ``writer.file_writer``
    in order.
    """
    from torch.utils.tensorboard.summary import hparams

    hparam_dict = {}
    for key, value in vars(args).items():
        hparam_dict[key] = value if isinstance(value, str) else repr(value)
    experiment, session_start, session_end = hparams(hparam_dict, {})
    for summary in (experiment, session_start, session_end):
        writer.file_writer.add_summary(summary)
def int_tuple(argstr):
    """Parse a comma-separated string such as ``'1,2,3'`` into ``(1, 2, 3)``.

    Each piece is passed to ``int()``, so surrounding whitespace is tolerated.
    Malformed pieces, including empty ones, raise ``ValueError``.
    """
    return tuple(int(piece) for piece in argstr.split(','))
def str_tuple(argstr):
    """Split a comma-separated string into a tuple, e.g. ``'a,b'`` -> ``('a', 'b')``.

    Pieces are kept verbatim; whitespace around them is not stripped.
    """
    pieces = argstr.split(',')
    return (*pieces,)
def get_checkpoint_path(folder, it=None):
    """Locate a checkpoint named ``<iteration>.pt`` inside ``folder``.

    Args:
        folder: Directory containing the checkpoints.
        it: Iteration to load. If given, its path is built directly and not
            checked for existence. If None, the highest-numbered checkpoint
            is chosen.

    Returns:
        ``(path, iteration)``.

    Raises:
        FileNotFoundError: if ``it`` is None and ``folder`` holds no ``*.pt`` file
            whose stem parses as an integer.
    """
    if it is not None:
        return os.path.join(folder, '%d.pt' % it), it

    candidates = []
    for path in glob(os.path.join(folder, '*.pt')):
        file_name = os.path.basename(path)
        try:
            iteration = int(file_name[:-len('.pt')])
        except ValueError:
            # Skip non-numeric checkpoints such as 'best.pt' or 'last.pt'.
            continue
        candidates.append((iteration, file_name))

    if not candidates:
        raise FileNotFoundError(
            'No iteration-numbered checkpoint (<int>.pt) found in %r' % (folder,))

    latest_iter, latest_name = max(candidates)
    # Return the real file name (e.g. '05.pt'), not one re-formatted from the int.
    return os.path.join(folder, latest_name), latest_iter
def load_config(config_path):
    """Load a YAML config file and derive a name for it.

    Args:
        config_path: Path to the YAML file.

    Returns:
        ``(config, config_name)``:

        - ``config`` is the parsed YAML wrapped in an ``EasyDict``.
        - ``config_name`` is the file's base name without its final extension,
          e.g. ``'train'`` for ``'configs/train.yml'``.
    """
    with open(config_path, 'r') as f:
        # safe_load only builds plain YAML types, never arbitrary Python objects.
        config = EasyDict(yaml.safe_load(f))
    # splitext keeps extension-less names intact; slicing at rfind('.') would
    # see -1 there and drop the name's last character.
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    return config, config_name
def extract_weights(weights: OrderedDict, prefix):
    """Return the entries of ``weights`` whose keys start with ``prefix``, prefix removed.

    Insertion order is preserved, and entries whose keys do not start with
    ``prefix`` are dropped. For example, keys ``'enc.a'`` and ``'dec.b'`` with
    ``prefix='enc.'`` yield a mapping with the single key ``'a'``.
    """
    cut = len(prefix)
    return OrderedDict(
        (key[cut:], value) for key, value in weights.items() if key.startswith(prefix)
    )
def current_milli_time():
    """Return the current Unix time in milliseconds, rounded to the nearest integer."""
    seconds = time.time()
    return round(seconds * 1000)