File size: 4,299 Bytes
66003a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import atexit
import logging
import uuid
from typing import Any, Dict, Optional, Union

import torch
from torch.utils.tensorboard import SummaryWriter

from .distributed import get_machine_local_and_dist_rank


class TensorBoardLogger:
    """A wrapper around TensorBoard SummaryWriter with distributed training support.

    A writer is only created on the process whose distributed rank is 0; on
    every other rank all logging methods are silent no-ops. ``close`` is
    registered with ``atexit`` so pending events are flushed at interpreter
    shutdown even if the caller never closes the logger explicitly.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: Optional[str] = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Initialize TensorBoard logger.

        Args:
            path: Directory path where TensorBoard logs will be stored
            filename_suffix: Optional suffix for log filename. If None or
                empty, a random UUID string is used instead
            summary_writer_method: SummaryWriter class or compatible alternative
            *args, **kwargs: Additional arguments passed to SummaryWriter
        """
        self._writer: Optional[SummaryWriter] = None
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path
        if self._rank == 0:
            logging.info(
                f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
            )
            # Positional unpacking is written before the keywords (same call
            # semantics as before, but no star-arg-after-keyword surprise).
            # NOTE(review): with the stock SummaryWriter, whose first
            # positional parameter is log_dir, any extra positional argument
            # collides with the log_dir keyword - prefer keyword arguments.
            self._writer = summary_writer_method(
                *args,
                log_dir=path,
                filename_suffix=filename_suffix or str(uuid.uuid4()),
                **kwargs,
            )
        else:
            logging.debug(
                f"Not logging on this process because rank {self._rank} != 0"
            )

        # Registered on every rank; close() is a no-op when there is no writer.
        atexit.register(self.close)

    @property
    def writer(self) -> Optional[SummaryWriter]:
        """Get the underlying writer, or None on non-zero ranks / after close()."""
        return self._writer

    @property
    def path(self) -> str:
        """Get the log directory path."""
        return self._path

    def flush(self) -> None:
        """Write pending logs to disk."""
        # Compare against None explicitly: a compatible writer may define
        # __len__/__bool__ and be falsy while still being perfectly usable.
        if self._writer is not None:
            self._writer.flush()

    def close(self) -> None:
        """Close writer and flush pending logs to disk.

        Logs cannot be written after close() is called. Calling close() more
        than once is safe; subsequent calls do nothing.
        """
        # Detach the writer first so that a failing close() is not retried
        # (e.g. by the atexit hook) on an already broken writer.
        writer, self._writer = self._writer, None
        if writer is not None:
            writer.close()

    def log_dict(self, payload: Dict[str, Any], step: int) -> None:
        """Log multiple scalar values to TensorBoard.

        Args:
            payload: Dictionary mapping tag names to scalar values
            step: Step value to record
        """
        if self._writer is None:
            return

        for key, value in payload.items():
            self.log(key, value, step)

    def log(self, name: str, data: Any, step: int) -> None:
        """Log scalar data to TensorBoard.

        Args:
            name: Tag name used to group scalars
            data: Scalar data to log (float/int/Tensor)
            step: Step value to record
        """
        if self._writer is None:
            return

        self._writer.add_scalar(name, data, global_step=step, new_style=True)

    def log_visuals(
        self,
        name: str,
        data: Union[torch.Tensor, Any],
        step: int,
        fps: int = 4
    ) -> None:
        """Log image or video data to TensorBoard.

        The number of dimensions of ``data`` selects the writer call:
        3 dimensions go to ``add_image`` and 5 dimensions go to ``add_video``.

        Args:
            name: Tag name used to group visuals
            data: Image tensor (3D) or video tensor (5D); must expose ``ndim``
            step: Step value to record
            fps: Frames per second for video data

        Raises:
            ValueError: If data dimensions are not supported (must be 3D or 5D)
        """
        if self._writer is None:
            return

        if data.ndim == 3:
            self._writer.add_image(name, data, global_step=step)
        elif data.ndim == 5:
            self._writer.add_video(name, data, global_step=step, fps=fps)
        else:
            raise ValueError(
                f"Unsupported data dimensions: {data.ndim}. "
                "Expected 3D for images or 5D for videos."
            )