| import logging | |
| import sys | |
| import torch | |
def get_torch_device():
    """Return the torch device to run on.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` reports
        True, otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def register_logger(log_file=None, stdout=True):
    """(Re)configure the root logger at INFO level.

    Every handler currently attached to the root logger is removed and
    closed, then replaced by the handlers requested here. Records are
    formatted as ``"%(asctime)s %(message)s"``.

    Args:
        log_file: Path of a file to log to; ``None`` disables file logging.
        stdout: When true, also log to ``sys.stdout``.

    Note:
        When ``stdout`` is false and ``log_file`` is ``None`` no handler is
        installed on the root logger.
    """
    root = logging.getLogger()  # root logger

    # Iterate over a copy: removeHandler() mutates root.handlers.
    for old_handler in root.handlers[:]:
        root.removeHandler(old_handler)
        # Close as well as detach, otherwise a previous FileHandler keeps its
        # file open for the rest of the process (descriptor leak on every
        # reconfiguration).
        old_handler.close()

    handlers = []
    if stdout:
        handlers.append(logging.StreamHandler(stream=sys.stdout))
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))

    logging.basicConfig(
        format="%(asctime)s %(message)s",
        handlers=handlers,
        level=logging.INFO,
    )
    # Set the level explicitly so it is INFO regardless of what
    # basicConfig() decided to do.
    root.setLevel(logging.INFO)