# Hosting-page residue (not Python): "Respair's picture" / "Upload folder using huggingface_hub" / revision b386992 (verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Union
import lightning.pytorch as pl
import torch
import torch.cuda
from torch import distributed
def reduce_value(
    value: Union[int, float],
    reduce_op: str = 'mean',
):
    """
    Reduce a value across distributed processes.

    Args:
        value (Union[int, float]): The value to reduce.
        reduce_op (str, optional): The reduction operation to perform. One of 'mean', 'avg',
            'sum', 'min', 'max'. Defaults to 'mean'.

    Returns:
        Union[int, float]: The reduced value. When torch.distributed is not initialized
        (single-process runs), the input value is returned unchanged.

    Raises:
        ValueError: If ``reduce_op`` is not one of the supported operations.
    """
    # Validate the op before anything else so a bad argument fails consistently,
    # whether or not distributed is initialized.
    if reduce_op not in ('mean', 'avg', 'sum', 'min', 'max'):
        raise ValueError(f'{reduce_op=} not supported.')
    # Single-process (or uninitialized process group): any reduction of one value
    # is the value itself; calling all_reduce here would raise.
    if not (distributed.is_available() and distributed.is_initialized()):
        return value
    if reduce_op in ('mean', 'avg', 'sum'):
        op = distributed.ReduceOp.SUM
    elif reduce_op == 'min':
        op = distributed.ReduceOp.MIN
    else:  # 'max'
        op = distributed.ReduceOp.MAX
    tensor_value = torch.tensor(value)
    distributed.all_reduce(tensor_value, op=op)
    if reduce_op in ('mean', 'avg'):
        # SUM across ranks, then divide by world size to get the mean.
        tensor_value = tensor_value / distributed.get_world_size()
    return tensor_value.item()
class MemoryMonitor(pl.Callback):
    """
    Logs the memory usage of the model.

    This callback calls the torch memory stats API for CUDA and reports different memory statistics.

    Example:
        import nemo_run as run
        from nemo.lightning.pytorch.callbacks import MemoryMonitor
        recipe.trainer.callbacks.append(
            run.Config(MemoryMonitor)
        )

    The memory statistics are logged under ``memory/{statistic}`` at the end of every
    training batch (``on_train_batch_end``).

    The following statistics are recorded:

    +------------------------+----------------------------------------------------------------------------------------+
    | Statistic              | Description                                                                            |
    +========================+========================================================================================+
    | current_allocated_mem  | Current amount of allocated memory in gigabytes.                                       |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_active_mem     | Current amount of active memory in gigabytes at the time of recording.                 |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_inactive_mem   | Current amount of inactive, non-releaseable memory in gigabytes.                       |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_reserved_mem   | Current amount of reserved memory in gigabytes at the time of recording.               |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_allocated_mem     | Peak amount of allocated memory in gigabytes.                                          |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_active_mem        | Peak amount of active memory in gigabytes at the time of recording.                    |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_inactive_mem      | Peak amount of inactive, non-releaseable memory in gigabytes at the time of recording. |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_reserved_mem      | Peak amount of reserved memory in gigabytes at the time of recording.                  |
    +------------------------+----------------------------------------------------------------------------------------+
    | alloc_retries          | Number of failed cudaMalloc calls that result in a cache flush and retry.              |
    +------------------------+----------------------------------------------------------------------------------------+

    Additionally, if `dist_aggregate_batch_interval` is enabled, the `avg`, `min`, and `max` of the
    aforementioned statistics are also logged.

    Args:
        memory_keys (dict[str, str], optional): A dict specifying memory statistics to log. Keys
            are the names of memory statistics to log from `torch.cuda.memory_stats()`, and values
            are the names they will be logged under. If not provided, the above statistics are
            logged. Defaults to None.
        dist_aggregate_batch_interval (int, optional): interval (in batches) for aggregating
            memory stats across all nodes. Defaults to None (by default the functionality
            is disabled).
    """

    def __init__(
        self,
        memory_keys: Optional[dict[str, str]] = None,
        dist_aggregate_batch_interval: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.memory_keys = memory_keys
        self.dist_aggregate_batch_interval = dist_aggregate_batch_interval

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: pl.utilities.types.STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Collect CUDA memory statistics after each training batch and log them."""
        memory_report = _get_memory_report(self.memory_keys)
        # Cross-rank aggregation is collective communication; only do it every
        # `dist_aggregate_batch_interval` batches to limit the overhead.
        if self.dist_aggregate_batch_interval and batch_idx % self.dist_aggregate_batch_interval == 0:
            dist_memory_report = {}
            for mem_stat, val in memory_report.items():
                dist_memory_report[mem_stat + '_avg'] = reduce_value(val, 'avg')
                dist_memory_report[mem_stat + '_min'] = reduce_value(val, 'min')
                dist_memory_report[mem_stat + '_max'] = reduce_value(val, 'max')
            memory_report.update(dist_memory_report)
        memory_metrics = {f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()}
        for metric, value in memory_metrics.items():
            # Callbacks have no `self.log`; logging must go through the LightningModule.
            pl_module.log(metric, value)
# Default mapping used by _get_memory_report: keys are entries of
# torch.cuda.memory_stats(), values are the names the statistics are logged under.
# "*_bytes" counters are converted to gigabytes before logging.
_MEMORY_KEYS = {
    'allocated_bytes.all.current': 'current_allocated_mem',
    'active_bytes.all.current': 'current_active_mem',
    'inactive_split_bytes.all.current': 'current_inactive_mem',
    'reserved_bytes.all.current': 'current_reserved_mem',
    'allocated_bytes.all.peak': 'peak_allocated_mem',
    'active_bytes.all.peak': 'peak_active_mem',
    'inactive_split_bytes.all.peak': 'peak_inactive_mem',
    'reserved_bytes.all.peak': 'peak_reserved_mem',
    'num_alloc_retries': 'alloc_retries',
}
def _get_memory_report(memory_keys: Optional[dict[str, str]] = None) -> dict[str, Union[int, float]]:
"""
Returns a dictionary with memory metrics.
Args:
memory_keys (Optional[dict[str, str]]): a dict specifying memory statistics to log.
Retuns:
dict: memory statistics.
"""
memory_stats = torch.cuda.memory_stats()
memory_keys = memory_keys or _MEMORY_KEYS
# simplify and reformat the memory_stats
memory_report = {}
for torch_name, name in memory_keys.items():
if torch_name in memory_stats:
# Convert to gigabytes
if 'bytes' in torch_name:
gigabytes = memory_stats[torch_name] / 1.0e9
# Round to preserve 5 significant digits
if gigabytes != 0:
order_of_magnitude = int(math.floor(math.log10(abs(gigabytes))))
gigabytes = round(gigabytes, -order_of_magnitude + 4)
memory_report[name.replace('bytes', 'gigabytes')] = gigabytes
else:
memory_report[name] = memory_stats[torch_name]
return memory_report