alverciito
upload safetensors and refactor research files
dbd79bd
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging
import os
import psutil
import time
import subprocess
from torch.utils.tensorboard import SummaryWriter
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Port on which the TensorBoard server is launched when `get_writer` is called
# with `tensorboard_port=True` (i.e. "enable TensorBoard, no explicit port").
DEFAULT_TENSORBOARD_PORT = 6006
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def get_writer(path: str, tensorboard_port: int | bool, logger: logging.Logger = None):
"""
Sets up a TensorBoard logging and checkpoint directory for PyTorch.
This function clears the specified directory, creates subdirectories for TensorBoard logs
and model checkpoints, ensuring a clean environment for running new training sessions.
Args:
path (str): The root directory where TensorBoard logs and checkpoints will be stored.
tensorboard_port (int): The port on which to run the TensorBoard.
logger (logging.Logger): The logger that traces the logging information.
Returns:
tuple: A tuple containing the TensorBoard SummaryWriter object and the path for checkpoints.
Example:
>>> tensor_writer, checkpoint_dir = get_writer('/path/to/tensorboard/')
"""
# Check tensorboard port:
if tensorboard_port is True:
tensorboard_port = DEFAULT_TENSORBOARD_PORT
elif tensorboard_port is False:
return None, os.path.join(path, 'checkpoints')
# Create subdirectories for logs and checkpoints
logs_path = os.path.join(path, 'logs')
checkpoints_path = os.path.join(path, 'checkpoints')
os.makedirs(logs_path, exist_ok=True)
os.makedirs(checkpoints_path, exist_ok=True)
# Set up TensorBoard logging
writer = SummaryWriter(log_dir=logs_path)
# Print paths where logs and checkpoints will be stored
if logger is not None:
logger.info(f"TensorBoard logs will be stored in: {logs_path}")
logger.info(f"Model checkpoints will be stored in: {checkpoints_path}")
# Launch tensorboard:
for conn in psutil.net_connections(kind='inet'):
if conn.laddr.port == tensorboard_port and conn.status == psutil.CONN_LISTEN:
if logger is not None:
logger.warning(f"Killing already running TensorBoard process with PID {conn.pid}")
p = psutil.Process(conn.pid)
p.terminate()
p.wait(timeout=3)
time.sleep(5)
process = subprocess.Popen(f'tensorboard --logdir={logs_path} --host=0.0.0.0 --port={tensorboard_port}',
shell=True)
if logger is not None:
logger.info(f'TensorBoard running at http://0.0.0.0:{tensorboard_port}/ (pid={process.pid})')
return writer, checkpoints_path
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #