File size: 3,201 Bytes
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging
import os
import psutil
import time
import subprocess
from torch.utils.tensorboard import SummaryWriter
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Port used by get_writer() when it is called with tensorboard_port=True.
DEFAULT_TENSORBOARD_PORT: int = 6006
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #


def get_writer(path: str, tensorboard_port: int | bool, logger: logging.Logger = None):
    """
    Sets up a TensorBoard logging and checkpoint directory for PyTorch.

    This function clears the specified directory, creates subdirectories for TensorBoard logs
    and model checkpoints, ensuring a clean environment for running new training sessions.

    Args:
        path (str): The root directory where TensorBoard logs and checkpoints will be stored.
        tensorboard_port (int): The port on which to run the TensorBoard.
        logger (logging.Logger): The logger that traces the logging information.

    Returns:
        tuple: A tuple containing the TensorBoard SummaryWriter object and the path for checkpoints.

    Example:
        >>> tensor_writer, checkpoint_dir = get_writer('/path/to/tensorboard/')
    """
    # Check tensorboard port:
    if tensorboard_port is True:
        tensorboard_port = DEFAULT_TENSORBOARD_PORT
    elif tensorboard_port is False:
        return None, os.path.join(path, 'checkpoints')

    # Create subdirectories for logs and checkpoints
    logs_path = os.path.join(path, 'logs')
    checkpoints_path = os.path.join(path, 'checkpoints')
    os.makedirs(logs_path, exist_ok=True)
    os.makedirs(checkpoints_path, exist_ok=True)

    # Set up TensorBoard logging
    writer = SummaryWriter(log_dir=logs_path)

    # Print paths where logs and checkpoints will be stored
    if logger is not None:
        logger.info(f"TensorBoard logs will be stored in: {logs_path}")
        logger.info(f"Model checkpoints will be stored in: {checkpoints_path}")

    # Launch tensorboard:
    for conn in psutil.net_connections(kind='inet'):
        if conn.laddr.port == tensorboard_port and conn.status == psutil.CONN_LISTEN:
            if logger is not None:
                logger.warning(f"Killing already running TensorBoard process with PID {conn.pid}")
            p = psutil.Process(conn.pid)
            p.terminate()
            p.wait(timeout=3)
            time.sleep(5)
    process = subprocess.Popen(f'tensorboard --logdir={logs_path} --host=0.0.0.0 --port={tensorboard_port}',
                               shell=True)
    if logger is not None:
        logger.info(f'TensorBoard running at http://0.0.0.0:{tensorboard_port}/ (pid={process.pid})')

    return writer, checkpoints_path
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #