|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import os |
|
|
import psutil |
|
|
import time |
|
|
import subprocess |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
DEFAULT_TENSORBOARD_PORT = 6006 |
|
|
|
|
|
|
|
|
|
|
|
def get_writer(path: str, tensorboard_port: int | bool, logger: logging.Logger = None): |
|
|
""" |
|
|
Sets up a TensorBoard logging and checkpoint directory for PyTorch. |
|
|
|
|
|
This function clears the specified directory, creates subdirectories for TensorBoard logs |
|
|
and model checkpoints, ensuring a clean environment for running new training sessions. |
|
|
|
|
|
Args: |
|
|
path (str): The root directory where TensorBoard logs and checkpoints will be stored. |
|
|
tensorboard_port (int): The port on which to run the TensorBoard. |
|
|
logger (logging.Logger): The logger that traces the logging information. |
|
|
|
|
|
Returns: |
|
|
tuple: A tuple containing the TensorBoard SummaryWriter object and the path for checkpoints. |
|
|
|
|
|
Example: |
|
|
>>> tensor_writer, checkpoint_dir = get_writer('/path/to/tensorboard/') |
|
|
""" |
|
|
|
|
|
if tensorboard_port is True: |
|
|
tensorboard_port = DEFAULT_TENSORBOARD_PORT |
|
|
elif tensorboard_port is False: |
|
|
return None, os.path.join(path, 'checkpoints') |
|
|
|
|
|
|
|
|
logs_path = os.path.join(path, 'logs') |
|
|
checkpoints_path = os.path.join(path, 'checkpoints') |
|
|
os.makedirs(logs_path, exist_ok=True) |
|
|
os.makedirs(checkpoints_path, exist_ok=True) |
|
|
|
|
|
|
|
|
writer = SummaryWriter(log_dir=logs_path) |
|
|
|
|
|
|
|
|
if logger is not None: |
|
|
logger.info(f"TensorBoard logs will be stored in: {logs_path}") |
|
|
logger.info(f"Model checkpoints will be stored in: {checkpoints_path}") |
|
|
|
|
|
|
|
|
for conn in psutil.net_connections(kind='inet'): |
|
|
if conn.laddr.port == tensorboard_port and conn.status == psutil.CONN_LISTEN: |
|
|
if logger is not None: |
|
|
logger.warning(f"Killing already running TensorBoard process with PID {conn.pid}") |
|
|
p = psutil.Process(conn.pid) |
|
|
p.terminate() |
|
|
p.wait(timeout=3) |
|
|
time.sleep(5) |
|
|
process = subprocess.Popen(f'tensorboard --logdir={logs_path} --host=0.0.0.0 --port={tensorboard_port}', |
|
|
shell=True) |
|
|
if logger is not None: |
|
|
logger.info(f'TensorBoard running at http://0.0.0.0:{tensorboard_port}/ (pid={process.pid})') |
|
|
|
|
|
return writer, checkpoints_path |
|
|
|
|
|
|
|
|
|
|
|
|