# ethanchern's picture
# init
# 873b6ec
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
from typing import Callable
import torch
from .logger import print_rank_0
class EventPathTimer:
    """
    Lightweight timer that logs the wall-clock time between consecutive events.

    No distributed barrier is issued: each process measures its own timeline.
    Only the most recent event (its message and timestamp) is remembered, so
    every call to ``synced_record`` reports the interval since the previous
    call on the same instance.
    """

    def __init__(self):
        """
        Initialize the EventPathTimer with no event recorded yet.

        Both attributes stay ``None`` until the first ``synced_record`` call.
        """
        self.prev_message: str | None = None
        self.prev_time: datetime | None = None

    def reset(self):
        """
        Forget the previously recorded event.

        After a reset, the next ``synced_record`` call only stores its event
        and prints nothing, exactly like the first call on a fresh instance.
        """
        self.prev_message = None
        self.prev_time = None

    def synced_record(self, message: str, print_fn: Callable[[str], None] = print_rank_0):
        """
        Record the current time under ``message`` and log the time since the last record.

        Args:
            message: Label for the event being recorded.
            print_fn: Callable that receives the formatted elapsed-time line.
                Defaults to ``print_rank_0``.

        When CUDA is available, pending CUDA work is synchronized before the
        timestamp is taken, so the reported interval includes GPU work that
        was queued before this call. If a previous event exists, the elapsed
        time between it and this one is passed to ``print_fn``. The current
        event then becomes the new "previous" event.
        """
        # torch.cuda.synchronize() raises on CPU-only builds / hosts without a
        # usable GPU; there is nothing to wait for there, so skip it and keep
        # the timer usable in those environments.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        current_time = datetime.now()

        # The first record after construction or reset() has no predecessor,
        # so there is no interval to report yet.
        if self.prev_message is not None:
            print_fn(
                f"\nTime Elapsed: [{current_time - self.prev_time}] From [{self.prev_message} ({self.prev_time})] To [{message} ({current_time})]"
            )

        self.prev_message = message
        self.prev_time = current_time
# Module-level EventPathTimer instance, created once at import time and
# handed out by event_path_timer().
_GLOBAL_LIGHT_TIMER = EventPathTimer()
def event_path_timer() -> EventPathTimer:
    """Return the module-level EventPathTimer.

    Returns:
        EventPathTimer: The instance stored in ``_GLOBAL_LIGHT_TIMER``.

    Raises:
        AssertionError: If ``_GLOBAL_LIGHT_TIMER`` is ``None`` (the check is an
            ``assert``, so it is skipped when Python runs with ``-O``).
    """
    timer = _GLOBAL_LIGHT_TIMER
    assert timer is not None, "light time recorder is not initialized"
    return timer