import logging
import sys

import torch


def get_torch_device():
    """Return ``cuda`` if a CUDA device is available, otherwise ``cpu``."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def register_logger(log_file=None, stdout=True, level=logging.INFO):
    """(Re)configure the root logger.

    Any handlers already attached to the root logger are removed *and
    closed*, so calling this function repeatedly does not leak open log
    files. New handlers are then installed with a ``"<time> <message>"``
    format.

    Args:
        log_file: Optional path of a file to append log records to.
        stdout: If true, also emit log records to ``sys.stdout``.
        level: Level for the root logger (defaults to ``logging.INFO``).
    """
    root = logging.getLogger()  # root logger

    # Iterate over a copy because removeHandler mutates root.handlers.
    for old in list(root.handlers):
        root.removeHandler(old)
        # Without close(), a FileHandler from a previous call would keep its
        # file descriptor open (and possibly unflushed) indefinitely.
        old.close()

    handlers = []
    if stdout:
        handlers.append(logging.StreamHandler(stream=sys.stdout))
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))

    # The root now has no handlers, so basicConfig applies the handlers,
    # format and level (it is a no-op only when handlers already exist).
    logging.basicConfig(
        format="%(asctime)s %(message)s",
        handlers=handlers,
        level=level,
    )