|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import atexit |
|
|
import logging |
|
|
import uuid |
|
|
from typing import Any, Dict, Optional, Union |
|
|
|
|
|
import torch |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
from .distributed import get_machine_local_and_dist_rank |
|
|
|
|
|
|
|
|
class TensorBoardLogger:
    """A wrapper around TensorBoard SummaryWriter with distributed training support.

    A writer is only created when the distributed rank reported by
    ``get_machine_local_and_dist_rank()`` is 0, so that a single process
    writes event files. On every other rank all logging methods are no-ops.
    When a writer is created, ``close()`` is registered with ``atexit`` so
    pending events are flushed at interpreter shutdown.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: Optional[str] = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Initialize TensorBoard logger.

        Args:
            path: Directory path where TensorBoard logs will be stored.
            *args: Extra positional arguments forwarded to
                ``summary_writer_method``. NOTE(review): with the stock
                ``SummaryWriter`` signature the first positional slot is
                ``log_dir``, which is also passed by keyword here, so
                positional extras will collide -- prefer keyword arguments.
            filename_suffix: Suffix for the event file name. When ``None``
                or empty, a random UUID string is used instead.
            summary_writer_method: SummaryWriter class or compatible
                alternative; called with ``log_dir`` and ``filename_suffix``
                keyword arguments.
            **kwargs: Extra keyword arguments forwarded to
                ``summary_writer_method``.
        """
        self._writer: Optional[SummaryWriter] = None
        # Only the second element of the returned pair is used as the rank.
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path

        if self._rank != 0:
            logging.debug(
                f"Not logging on this process because rank {self._rank} != 0"
            )
            return

        logging.info(
            f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
        )
        self._writer = summary_writer_method(
            *args,
            log_dir=path,
            filename_suffix=filename_suffix or str(uuid.uuid4()),
            **kwargs,
        )
        # Register the shutdown hook only when there is something to close:
        # atexit keeps a reference to the bound method (and thus to this
        # instance) until interpreter exit, which is pointless on ranks
        # that never own a writer.
        atexit.register(self.close)

    @property
    def writer(self) -> Optional[SummaryWriter]:
        """Get the underlying SummaryWriter instance (``None`` if absent or closed)."""
        return self._writer

    @property
    def path(self) -> str:
        """Get the log directory path."""
        return self._path

    def flush(self) -> None:
        """Write pending logs to disk. No-op when there is no active writer."""
        if self._writer is not None:
            self._writer.flush()

    def close(self) -> None:
        """Close writer and flush pending logs to disk.

        Logs cannot be written after close() is called. Calling close()
        more than once is safe; later calls do nothing.
        """
        # Detach the writer before closing it so that, even if the
        # underlying close() raises, this logger is left in the "closed"
        # state and the atexit hook does not retry on a broken writer.
        writer, self._writer = self._writer, None
        if writer is not None:
            writer.close()

    def log_dict(self, payload: Dict[str, Any], step: int) -> None:
        """Log multiple scalar values to TensorBoard.

        Args:
            payload: Dictionary mapping tag names to scalar values.
            step: Step value to record for every entry.
        """
        if self._writer is None:
            return

        for key, value in payload.items():
            self.log(key, value, step)

    def log(self, name: str, data: Any, step: int) -> None:
        """Log scalar data to TensorBoard.

        Args:
            name: Tag name used to group scalars.
            data: Scalar data to log (float/int/Tensor).
            step: Step value to record.
        """
        if self._writer is None:
            return

        self._writer.add_scalar(name, data, global_step=step, new_style=True)

    def log_visuals(
        self,
        name: str,
        data: Union[torch.Tensor, Any],
        step: int,
        fps: int = 4,
    ) -> None:
        """Log image or video data to TensorBoard.

        Dispatch is based solely on ``data.ndim``: 3 dimensions are sent to
        ``add_image`` and 5 dimensions to ``add_video``. The axis layout is
        whatever those writer methods expect; it is not checked here.

        Args:
            name: Tag name used to group visuals.
            data: Image tensor (3D) or video tensor (5D); any object
                exposing ``ndim`` is accepted.
            step: Step value to record.
            fps: Frames per second for video data (ignored for images).

        Raises:
            ValueError: If data dimensions are not supported (must be 3D or 5D).
        """
        if self._writer is None:
            return

        if data.ndim == 3:
            self._writer.add_image(name, data, global_step=step)
        elif data.ndim == 5:
            self._writer.add_video(name, data, global_step=step, fps=fps)
        else:
            raise ValueError(
                f"Unsupported data dimensions: {data.ndim}. "
                "Expected 3D for images or 5D for videos."
            )
|
|
|