# bluemellophone's picture
# Fix image
# cd605d0 unverified
# -*- coding: utf-8 -*-
'''
Various utility functions used (possibly) across scripts.
2022 Benjamin Kellenberger
'''
import logging
import random
from logging.handlers import TimedRotatingFileHandler
import torch
import yaml
from torch.backends import cudnn
# Number of rotated log files to keep; passed as `backupCount` to the
# midnight-rotating file handler created in init_logging().
DAYS = 21
def init_logging():
    """
    Configure Python's built-in logging with two handlers: a rotating on-disk
    log file and a Rich console handler for prettier terminal output.

    Returns the logger named 'lecture'.
    """
    # Rich is imported lazily, only when logging is actually initialised
    import rich
    from rich.logging import RichHandler
    from rich.style import Style
    from rich.theme import Theme

    logger_name = 'lecture'

    # Colour scheme applied to Rich's global console
    palette = Theme(
        {
            'logging.keyword': Style(bold=True, color='yellow'),
            'logging.level.notset': Style(dim=True),
            'logging.level.debug': Style(color='cyan'),
            'logging.level.info': Style(color='green'),
            'logging.level.warning': Style(color='yellow'),
            'logging.level.error': Style(color='red', bold=True),
            'logging.level.critical': Style(color='red', bold=True, reverse=True),
            'log.time': Style(color='white'),
        }
    )

    # File-based handler: rotates at midnight and keeps DAYS backups
    file_handler = TimedRotatingFileHandler(
        filename=f'{logger_name}.log',
        when='midnight',
        backupCount=DAYS,
    )

    # Rich (fancy console) handler, using the reconfigured global console
    rich.reconfigure(theme=palette)
    console_handler = RichHandler(
        rich_tracebacks=True,
        tracebacks_show_locals=True,
    )

    # Install both handlers on the root logger and default its level to INFO
    logging.basicConfig(
        handlers=[file_handler, console_handler],
        level=logging.INFO,
        format='[%(name)s] %(message)s',
        datefmt='[%X]',
    )
    logging.getLogger().setLevel(logging.INFO)

    return logging.getLogger(logger_name)
def init_seed(seed):
    """
    Seed the random number generators used by this module (Python's `random`
    and PyTorch on CPU and CUDA) and configure cuDNN for reproducible runs.

    Args:
        seed: integer seed. If None, nothing is seeded and the cuDNN flags
            are left untouched.
    """
    if seed is None:
        return

    random.seed(seed)
    # numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not just the current one; this call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    # Autotuning (benchmark=True) may select different convolution algorithms
    # from run to run, which defeats the purpose of seeding, so disable it
    # and restrict cuDNN to deterministic algorithms.
    cudnn.benchmark = False
    cudnn.deterministic = True
def init_config(config, log):
    """
    Load the YAML configuration file and resolve the compute device.

    Args:
        config: path to the YAML configuration file.
        log: logger used for status messages; it is also stored in the
            returned dict under the key 'log'.

    Returns:
        The parsed configuration dict, with 'log' added and 'device' set to
        one of 'cuda', 'mps' or 'cpu'.
    """
    # load config; the context manager closes the file handle after parsing
    log.info(f'Using config "{config}"')
    with open(config, 'r') as f:
        cfg = yaml.safe_load(f)
    cfg['log'] = log

    # Unless CPU is explicitly requested, pick an available accelerator
    # (CUDA first, then Apple MPS) and fall back to CPU with a warning.
    device = cfg.get('device')
    if device not in ['cpu']:
        if torch.cuda.is_available():
            cfg['device'] = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            cfg['device'] = 'mps'
        else:
            log.warning(
                f'WARNING: device set to "{device}" but not available; falling back to CPU...'
            )
            cfg['device'] = 'cpu'

    device = cfg.get('device')
    log.info(f'Using device "{device}"')
    return cfg