File size: 6,379 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, replace
from pathlib import Path
from typing import Callable, Optional, Tuple

import torch
from torch.profiler import ProfilerActivity, profile

from .core import PhysicsNeMoProfilerWrapper, _Profiler_Singleton


@dataclass
class TorchProfilerConfig:
    """
    Settings bundle consumed by the PyTorch profiler wrapper.

    Field order is part of the positional constructor signature, so new
    fields should only ever be appended.

    Attributes:
        name: Identifier for this profiler configuration.
        torch_prof_activities: Profiler activities to monitor. ``None``
            leaves the selection to the wrapper, which picks CPU and adds
            CUDA when ``torch.cuda.is_available()``.
        record_shapes: Record operator input shapes.
        with_stack: Record source stack traces.
        profile_memory: Track tensor memory usage.
        with_flops: Estimate FLOPs for supported operators.
        schedule: Optional callable forwarded as the profiler ``schedule``.
        on_trace_ready_path: Optional location handed to
            ``torch.profiler.tensorboard_trace_handler``; when ``None`` no
            trace handler is installed.
    """

    # Identification
    name: str = "torch"

    # What to profile (None -> decided by the wrapper at construction time)
    torch_prof_activities: Optional[Tuple[ProfilerActivity, ...]] = None

    # Feature toggles forwarded to torch.profiler.profile
    record_shapes: bool = True
    with_stack: bool = False
    profile_memory: bool = True
    with_flops: bool = True

    # Optional scheduling / trace output hooks
    schedule: Optional[Callable] = None
    on_trace_ready_path: Optional[Path] = None


class TorchProfileWrapper(PhysicsNeMoProfilerWrapper, metaclass=_Profiler_Singleton):
    """Wrapper class for PyTorch profiler functionality.

    This class wraps ``torch.profiler.profile`` so it can be driven through
    PhysicsNeMo's profiling system. The underlying profiler object is only
    created in :meth:`_standup`; until then the wrapper just holds its
    configuration.

    Attributes:
        _name: Name identifier for this profiler
        _is_context: Whether this profiler supports context manager usage
        _is_decorator: Whether this profiler supports decorator usage
    """

    _name: str = "torch"

    # Overload any of these:
    _is_context: bool = True
    _is_decorator: bool = False

    def __init__(
        self, config: Optional[TorchProfilerConfig] = None, **config_overrides
    ) -> None:
        """Initialize the PyTorch profiler wrapper.

        Args:
            config: Optional configuration object for the profiler. When
                omitted, a default ``TorchProfilerConfig`` is used.
            **config_overrides: Optional keyword arguments to override config
                values. Names must be fields of ``TorchProfilerConfig``
                (``dataclasses.replace`` raises ``TypeError`` otherwise).
        """
        base_config = TorchProfilerConfig() if config is None else config

        # replace() always builds a new object, so the caller's config is
        # never mutated by the defaulting below.
        self._config = replace(base_config, **config_overrides)

        # Set the default profiling activities if not set: CPU always, CUDA
        # only when a device is actually available.
        if self._config.torch_prof_activities is None:
            activities = [ProfilerActivity.CPU]
            if torch.cuda.is_available():
                activities.append(ProfilerActivity.CUDA)
            # Stored as a tuple to match the field's declared type.
            self._config.torch_prof_activities = tuple(activities)

    def _standup(self) -> None:
        """Create the underlying ``torch.profiler.profile`` object.

        Installs a tensorboard trace handler when ``on_trace_ready_path`` is
        configured, then marks the wrapper as initialized.
        """
        if self._config.on_trace_ready_path is not None:
            on_trace_ready = torch.profiler.tensorboard_trace_handler(
                self._config.on_trace_ready_path
            )
        else:
            on_trace_ready = None

        self._profiler = profile(
            activities=self._config.torch_prof_activities,
            profile_memory=self._config.profile_memory,
            record_shapes=self._config.record_shapes,
            with_stack=self._config.with_stack,
            schedule=self._config.schedule,
            with_flops=self._config.with_flops,
            on_trace_ready=on_trace_ready,
        )

        self._initialized = True

    def finalize(self, output_top: Path) -> None:
        """Finalize profiling and write results to disk.

        Writes ``cpu_time.txt`` and ``gpu_time.txt`` (operator tables) when
        any averages were recorded, and ``trace.json`` (chrome trace) when no
        ``on_trace_ready_path`` was configured. Does nothing when the
        profiler is disabled or already finalized.

        Args:
            output_top: Base output directory path for profiling results
        """
        if not self.enabled:
            return

        # Avoid finalizing twice:
        if self.finalized:
            return

        # Get the output directory:
        out_top = self.output_dir(output_top)

        # _profiler only exists after _standup(); tolerate a wrapper that was
        # enabled but never stood up instead of raising AttributeError.
        profiler = getattr(self, "_profiler", None)

        # profiler.profiler stays None until the torch profiler has actually
        # been started, in which case there is nothing to report.
        if profiler is not None and profiler.profiler is not None:
            try:
                averages = profiler.key_averages()
            except AssertionError:
                # no averages recorded!
                averages = None

            # Write out torch profiling results:
            if averages:
                with open(
                    out_top / Path("cpu_time.txt"), "w", encoding="utf-8"
                ) as cpu_times:
                    cpu_times.write(averages.table())

                with open(
                    out_top / Path("gpu_time.txt"), "w", encoding="utf-8"
                ) as gpu_times:
                    gpu_times.write(averages.table(sort_by="cuda_time_total"))

            if self._config.on_trace_ready_path is None:
                # No trace handler was installed, so store the trace here.
                trace_path = out_top / Path("trace.json")
                profiler.export_chrome_trace(str(trace_path))

        # Make this profiler completed:
        self.finalized = True

    def __enter__(self) -> "TorchProfileWrapper":
        """Enter the profiling context.

        Requires that :meth:`_standup` has already created the profiler.

        Returns:
            Self reference for context manager usage
        """
        self._profiler.__enter__()
        return self

    def __exit__(self, *exc: object) -> None:
        """Exit the profiling context.

        Args:
            *exc: The ``(exc_type, exc_value, traceback)`` triple supplied by
                the ``with`` statement, forwarded unchanged to the torch
                profiler. Exceptions are never suppressed here.
        """
        self._profiler.__exit__(*exc)

    def step(self) -> None:
        """Advance the profiler's step counter."""
        self._profiler.step()