File size: 3,102 Bytes
2192664
8ee14ff
 
 
 
 
 
 
2192664
8ee14ff
 
2192664
 
 
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
af45807
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd605d0
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd605d0
8ee14ff
 
2192664
 
 
8ee14ff
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
'''
    Various utility functions used (possibly) across scripts.

    2022 Benjamin Kellenberger
'''

import logging
import random
from logging.handlers import TimedRotatingFileHandler

import torch
import yaml
from torch.backends import cudnn

DAYS = 21


def init_logging():
    """
    Set up Python's built-in logging with an on-disk rotating log file and
    prettier console output via Rich.

    Returns:
        logging.Logger: the named logger ('lecture') to use across scripts.
    """
    # Rich is imported lazily so importing this module does not require Rich
    # unless logging is actually initialised.
    import rich
    from rich.logging import RichHandler
    from rich.style import Style
    from rich.theme import Theme

    name = 'lecture'

    # Placeholder for logging handlers
    handlers = []

    # Configuration arguments for console, handlers, and logging
    console_kwargs = {
        'theme': Theme(
            {
                'logging.keyword': Style(bold=True, color='yellow'),
                'logging.level.notset': Style(dim=True),
                'logging.level.debug': Style(color='cyan'),
                'logging.level.info': Style(color='green'),
                'logging.level.warning': Style(color='yellow'),
                'logging.level.error': Style(color='red', bold=True),
                'logging.level.critical': Style(color='red', bold=True, reverse=True),
                'log.time': Style(color='white'),
            }
        )
    }
    handler_kwargs = {
        'rich_tracebacks': True,
        'tracebacks_show_locals': True,
    }
    logging_kwargs = {
        'level': logging.INFO,
        'format': '[%(name)s] %(message)s',
        'datefmt': '[%X]',
    }

    # File-based log handler: rotates at midnight, keeps DAYS days of backups.
    # encoding pinned so log contents are stable across platforms.
    handlers.append(
        TimedRotatingFileHandler(
            filename=f'{name}.log',
            when='midnight',
            backupCount=DAYS,
            encoding='utf-8',
        ),
    )

    # Rich (fancy logging) log handler
    rich.reconfigure(**console_kwargs)
    handlers.append(RichHandler(**handler_kwargs))

    # Configure the root logger; the level is already supplied via
    # logging_kwargs, so no separate setLevel() call is needed.
    logging.basicConfig(handlers=handlers, **logging_kwargs)

    return logging.getLogger(name)


def init_seed(seed):
    """
    Seed all relevant random number generators for reproducible runs.

    Args:
        seed (int | None): seed value; if None, nothing is seeded and the
            run remains non-deterministic.
    """
    if seed is None:
        return
    random.seed(seed)
    # numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    # For reproducibility cuDNN autotuning must be OFF: benchmark=True lets
    # cuDNN pick potentially non-deterministic kernels, which defeats
    # deterministic=True. (The original set benchmark=True — a bug.)
    cudnn.benchmark = False
    cudnn.deterministic = True


def init_config(config, log):
    """
    Load a YAML configuration file and resolve the compute device.

    Args:
        config (str): path to the YAML configuration file.
        log (logging.Logger): logger; stored in the returned config under 'log'.

    Returns:
        dict: parsed configuration with 'log' attached and 'device' resolved
        to one of 'cuda', 'mps', or 'cpu' depending on availability.
    """
    # load config
    log.info(f'Using config "{config}"')
    # Context manager closes the file deterministically (the original left
    # the handle open); an empty YAML file yields None, so fall back to {}.
    with open(config, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f) or {}

    cfg['log'] = log

    # check if GPU (or Apple MPS) is available; only an explicit 'cpu'
    # request skips the probe
    device = cfg.get('device')
    if device != 'cpu':
        if torch.cuda.is_available():
            cfg['device'] = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            cfg['device'] = 'mps'
        else:
            log.warning(
                f'WARNING: device set to "{device}" but not available; falling back to CPU...'
            )
            cfg['device'] = 'cpu'

    device = cfg.get('device')
    log.info(f'Using device "{device}"')

    return cfg