# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#      This file was created by: Alberto Palomo Alonso      #
#   Universidad de Alcalá - Escuela Politécnica Superior    #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging
import os
import subprocess
import time

import psutil
from torch.utils.tensorboard import SummaryWriter
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
DEFAULT_TENSORBOARD_PORT = 6006
# Seconds granted to a previous listener to exit after being asked to terminate.
_TERMINATE_TIMEOUT = 3
# Seconds to wait after stopping a listener before starting a new TensorBoard on the same port.
_PORT_RELEASE_DELAY = 5
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #


def _terminate_listeners(port: int, logger: logging.Logger | None = None) -> bool:
    """
    Stops every process that is listening on ``port``.

    Args:
        port (int): The local TCP/UDP port to free.
        logger (logging.Logger | None): Optional logger for warnings.

    Returns:
        bool: True if at least one listening process was stopped (or had already exited),
        False if nothing was listening or nothing could be stopped.
    """
    try:
        connections = psutil.net_connections(kind='inet')
    except psutil.AccessDenied:
        # Listing system-wide connections needs elevated privileges on some platforms;
        # the port check is best-effort, so carry on and let TensorBoard try to bind.
        if logger is not None:
            logger.warning(f"Not allowed to inspect network connections; cannot check port {port}")
        return False

    own_pid = os.getpid()
    handled_pids = set()
    stopped_any = False
    for conn in connections:
        if conn.status != psutil.CONN_LISTEN or not conn.laddr or conn.laddr.port != port:
            continue
        pid = conn.pid
        if pid is None:
            # psutil reports pid=None when the owner is not retrievable. Passing None to
            # psutil.Process() would select the *current* process, so never do that.
            if logger is not None:
                logger.warning(f"Port {port} is in use by a process whose PID cannot be determined")
            continue
        if pid == own_pid or pid in handled_pids:
            # Never terminate ourselves; a process listening on IPv4 and IPv6 shows up twice.
            continue
        handled_pids.add(pid)

        if logger is not None:
            logger.warning(f"Killing already running TensorBoard process with PID {pid}")
        try:
            proc = psutil.Process(pid)
            proc.terminate()
            try:
                proc.wait(timeout=_TERMINATE_TIMEOUT)
            except psutil.TimeoutExpired:
                # The process ignored the polite request; force it so the port is released.
                proc.kill()
        except psutil.NoSuchProcess:
            # It exited on its own between the scan and the signal; the port is being freed.
            pass
        except psutil.AccessDenied:
            if logger is not None:
                logger.warning(f"Not allowed to terminate process with PID {pid}")
            continue
        stopped_any = True
    return stopped_any


def get_writer(
        path: str,
        tensorboard_port: int | bool,
        logger: logging.Logger | None = None
) -> tuple[SummaryWriter | None, str]:
    """
    Sets up TensorBoard logging and a checkpoint directory for PyTorch.

    Creates ``<path>/logs`` and ``<path>/checkpoints`` if they do not exist (existing content
    is kept), opens a ``SummaryWriter`` on the logs directory, stops any process that is
    already listening on the requested port and launches a TensorBoard server on it.

    Args:
        path (str): The root directory where TensorBoard logs and checkpoints will be stored.
        tensorboard_port (int | bool): The port on which to run TensorBoard. ``True`` selects
            ``DEFAULT_TENSORBOARD_PORT``. ``False`` disables TensorBoard entirely: no writer is
            created, no server is launched and no directory is created.
        logger (logging.Logger | None): The logger that traces the logging information.

    Returns:
        tuple: A tuple containing the TensorBoard SummaryWriter object (``None`` when
        ``tensorboard_port`` is ``False``) and the path for checkpoints.

    Example:
        >>> tensor_writer, checkpoint_dir = get_writer('/path/to/tensorboard/', True)  # doctest: +SKIP
    """
    # Check tensorboard port:
    if tensorboard_port is True:
        tensorboard_port = DEFAULT_TENSORBOARD_PORT
    elif tensorboard_port is False:
        return None, os.path.join(path, 'checkpoints')

    # Create subdirectories for logs and checkpoints
    logs_path = os.path.join(path, 'logs')
    checkpoints_path = os.path.join(path, 'checkpoints')
    os.makedirs(logs_path, exist_ok=True)
    os.makedirs(checkpoints_path, exist_ok=True)

    # Set up TensorBoard logging
    writer = SummaryWriter(log_dir=logs_path)

    # Print paths where logs and checkpoints will be stored
    if logger is not None:
        logger.info(f"TensorBoard logs will be stored in: {logs_path}")
        logger.info(f"Model checkpoints will be stored in: {checkpoints_path}")

    # Free the port; only wait for it to be released if something was actually stopped.
    if _terminate_listeners(tensorboard_port, logger):
        time.sleep(_PORT_RELEASE_DELAY)

    # Launch tensorboard. An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact and makes process.pid the PID of TensorBoard itself.
    command = [
        'tensorboard',
        f'--logdir={logs_path}',
        '--host=0.0.0.0',
        f'--port={tensorboard_port}',
    ]
    try:
        process = subprocess.Popen(command)
    except OSError as error:
        # Launching the viewer is best-effort: the writer is still usable without it.
        if logger is not None:
            logger.error(f"Could not launch TensorBoard: {error}")
    else:
        if logger is not None:
            logger.info(f'TensorBoard running at http://0.0.0.0:{tensorboard_port}/ (pid={process.pid})')

    return writer, checkpoints_path
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #