# ArthurY's picture
# update source
# c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Callable, Optional, Tuple
import torch
from torch.profiler import ProfilerActivity, profile
from .core import PhysicsNeMoProfilerWrapper, _Profiler_Singleton
@dataclass
class TorchProfilerConfig:
    """Settings consumed by ``TorchProfileWrapper`` when it builds the torch profiler.

    Every field except ``name`` and ``on_trace_ready_path`` is forwarded to
    ``torch.profiler.profile`` under the keyword noted below.

    Attributes:
        name: Identifier for this profiler configuration.
        torch_prof_activities: Activities handed to the profiler as
            ``activities``. When left as ``None`` the wrapper selects CPU, plus
            CUDA if ``torch.cuda.is_available()`` reports a device.
        record_shapes: Forwarded as ``record_shapes``.
        with_stack: Forwarded as ``with_stack``.
        profile_memory: Forwarded as ``profile_memory``.
        with_flops: Forwarded as ``with_flops``.
        schedule: Optional callable forwarded as ``schedule``.
        on_trace_ready_path: Optional directory. When set, the wrapper installs
            a ``tensorboard_trace_handler`` writing there; when ``None`` the
            wrapper exports a chrome trace during ``finalize`` instead.
    """

    # Identification
    name: str = "torch"

    # What to profile (None -> chosen by the wrapper at construction time)
    torch_prof_activities: Optional[Tuple[ProfilerActivity, ...]] = None

    # Per-op detail switches passed straight through to torch.profiler.profile
    record_shapes: bool = True
    with_stack: bool = False
    profile_memory: bool = True
    with_flops: bool = True

    # Step scheduling and trace output
    schedule: Optional[Callable] = None
    on_trace_ready_path: Optional[Path] = None
class TorchProfileWrapper(PhysicsNeMoProfilerWrapper, metaclass=_Profiler_Singleton):
    """Wrapper class for PyTorch profiler functionality.

    This class wraps PyTorch's built-in profiler to integrate with PhysicsNeMo's
    profiling system. It supports context manager usage for profiling code blocks.

    Attributes:
        _name: Name identifier for this profiler
        _is_context: Whether this profiler supports context manager usage
        _is_decorator: Whether this profiler supports decorator usage
        _profiler: The underlying ``torch.profiler.profile`` object, or ``None``
            until ``_standup`` has been called
    """

    _name: str = "torch"

    # Overload any of these:
    _is_context: bool = True
    _is_decorator: bool = False

    # Class-level default so that ``finalize`` and ``step`` can be called safely
    # before ``_standup`` has created the real profiler on the instance.
    _profiler: Optional[profile] = None

    def __init__(
        self, config: Optional[TorchProfilerConfig] = None, **config_overrides
    ) -> None:
        """Initialize the PyTorch profiler wrapper.

        Args:
            config: Optional configuration object for the profiler. A default
                ``TorchProfilerConfig`` is used when omitted.
            **config_overrides: Optional keyword arguments to override config
                values; each key must be a ``TorchProfilerConfig`` field name.
        """
        base_config = TorchProfilerConfig() if config is None else config

        # ``replace`` returns a new dataclass instance, so the caller's config
        # object is never modified by the defaulting below.
        self._config = replace(base_config, **config_overrides)

        # Set the default profiling activities if not set:
        if self._config.torch_prof_activities is None:
            activities = [ProfilerActivity.CPU]
            if torch.cuda.is_available():
                activities.append(ProfilerActivity.CUDA)
            # Stored as a tuple to match the field's declared type.
            self._config.torch_prof_activities = tuple(activities)

    def _standup(self) -> None:
        """Initialize the PyTorch profiler with configured settings."""
        trace_dir = self._config.on_trace_ready_path
        if trace_dir is not None:
            # tensorboard_trace_handler documents ``dir_name`` as ``str``.
            on_trace_ready = torch.profiler.tensorboard_trace_handler(str(trace_dir))
        else:
            on_trace_ready = None

        self._profiler = profile(
            activities=self._config.torch_prof_activities,
            profile_memory=self._config.profile_memory,
            record_shapes=self._config.record_shapes,
            with_stack=self._config.with_stack,
            schedule=self._config.schedule,
            with_flops=self._config.with_flops,
            on_trace_ready=on_trace_ready,
        )
        self._initialized = True

    def finalize(self, output_top: Path) -> None:
        """Finalize profiling and write results to disk.

        Writes ``cpu_time.txt`` and ``gpu_time.txt`` (when averaged events are
        available) and, if no ``on_trace_ready_path`` was configured, a
        ``trace.json`` chrome trace. Does nothing when the profiler is disabled
        or already finalized; writes nothing when the torch profiler was never
        stood up or never started.

        Args:
            output_top: Base output directory path for profiling results
        """
        if not self.enabled:
            return

        # Avoid finalizing twice:
        if self.finalized:
            return

        # Get the output directory:
        out_top = self.output_dir(output_top)

        # ``_profiler`` stays ``None`` if ``_standup`` never ran; its inner
        # ``profiler`` attribute is ``None`` if profiling never started.
        torch_profiler = self._profiler
        if torch_profiler is not None and torch_profiler.profiler is not None:
            try:
                averages = torch_profiler.key_averages()
            except AssertionError:
                # no averages recorded!
                averages = None

            # Write out torch profiling results:
            if averages:
                with open(
                    out_top / Path("cpu_time.txt"), "w", encoding="utf-8"
                ) as cpu_times:
                    cpu_times.write(averages.table())

                with open(
                    out_top / Path("gpu_time.txt"), "w", encoding="utf-8"
                ) as gpu_times:
                    gpu_times.write(averages.table(sort_by="cuda_time_total"))

            # With a trace-ready handler configured, traces are written by that
            # handler instead, so only export here when none was configured.
            if self._config.on_trace_ready_path is None:
                trace_path = out_top / Path("trace.json")
                torch_profiler.export_chrome_trace(str(trace_path))

        # Make this profiler completed:
        self.finalized = True

    def __enter__(self) -> "TorchProfileWrapper":
        """Enter the profiling context.

        ``_standup`` must have been called first so that the underlying torch
        profiler exists.

        Returns:
            Self reference for context manager usage
        """
        self._profiler.__enter__()
        return self

    def __exit__(self, *exc: object) -> None:
        """Exit the profiling context.

        Args:
            *exc: The ``(exc_type, exc_value, traceback)`` triple supplied by the
                ``with`` statement, forwarded unchanged to the torch profiler
        """
        self._profiler.__exit__(*exc)

    def step(self) -> None:
        """Advance the profiler's step counter.

        This is a no-op if the torch profiler has not been stood up yet.
        """
        if self._profiler is not None:
            self._profiler.step()