File size: 8,978 Bytes
b386992 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Union
import lightning.pytorch as pl
import torch
import torch.cuda
from torch import distributed
def reduce_value(
    value: Union[int, float],
    reduce_op: str = 'mean',
):
    """
    Reduce a scalar value across distributed processes.

    Args:
        value (Union[int, float]): The local value to reduce.
        reduce_op (str, optional): The reduction operation to perform. One of 'mean', 'avg', 'sum', 'min', 'max'.
            Defaults to 'mean'.

    Returns:
        Union[int, float]: The reduced value as a Python scalar. If no default process group has been
        initialized there is nothing to reduce with, so ``value`` is returned unchanged.

    Raises:
        ValueError: If ``reduce_op`` is not one of the supported operations.
    """
    # Validate first so that an unsupported op is reported even in single-process runs.
    if reduce_op not in ('mean', 'avg', 'sum', 'min', 'max'):
        raise ValueError(f'{reduce_op=} not supported.')

    # Without an initialized process group the collective below would raise;
    # the local value is already the result of every supported reduction.
    if not (distributed.is_available() and distributed.is_initialized()):
        return value

    if reduce_op == 'min':
        op = distributed.ReduceOp.MIN
    elif reduce_op == 'max':
        op = distributed.ReduceOp.MAX
    else:
        # 'mean' / 'avg' are computed as SUM followed by a division by the world size.
        op = distributed.ReduceOp.SUM

    # NCCL collectives only operate on CUDA tensors, so the scalar must live on
    # this rank's current GPU; other backends keep the default (CPU) placement.
    device = None
    if distributed.get_backend() == 'nccl':
        device = torch.device('cuda', torch.cuda.current_device())

    tensor_value = torch.tensor(value, device=device)
    distributed.all_reduce(tensor_value, op=op)
    if reduce_op in ('mean', 'avg'):
        tensor_value = tensor_value / distributed.get_world_size()
    return tensor_value.item()
class MemoryMonitor(pl.Callback):
    """
    Logs the memory usage of the model.

    This callback calls the torch memory stats API for CUDA and reports different memory statistics.

    Example:
        import nemo_run as run
        from nemo.lightning.pytorch.callbacks import MemoryMonitor

        recipe.trainer.callbacks.append(
            run.Config(MemoryMonitor)
        )

    The memory statistics are logged through ``self.log`` to the following keys as
    described below.

    +--------------------------+-------------------------------------------------------------+
    | Key                      | Logged data                                                 |
    +==========================+=============================================================+
    |                          | Several memory usage statistics                             |
    | ``memory/{statistic}``   | are logged at the end of each                               |
    |                          | training batch (``on_train_batch_end``).                    |
    +--------------------------+-------------------------------------------------------------+

    The following statistics are recorded:

    +------------------------+----------------------------------------------------------------------------------------+
    | Statistic              | Description                                                                            |
    +========================+========================================================================================+
    | current_allocated_mem  | Current amount of allocated memory in gigabytes.                                       |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_active_mem     | Current amount of active memory in gigabytes at the time of recording.                 |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_inactive_mem   | Current amount of inactive, non-releaseable memory in gigabytes.                       |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_reserved_mem   | Current amount of reserved memory in gigabytes at the time of recording.               |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_allocated_mem     | Peak amount of allocated memory in gigabytes.                                          |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_active_mem        | Peak amount of active memory in gigabytes at the time of recording.                    |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_inactive_mem      | Peak amount of inactive, non-releaseable memory in gigabytes at the time of recording. |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_reserved_mem      | Peak amount of reserved memory in gigabytes at the time of recording.                  |
    +------------------------+----------------------------------------------------------------------------------------+
    | alloc_retries          | Number of failed cudaMalloc calls that result in a cache flush and retry.              |
    +------------------------+----------------------------------------------------------------------------------------+

    Additionally, if `dist_aggregate_batch_interval` is set, the `avg`, `min`, and `max` of the
    aforementioned statistics across all ranks are also logged (as ``{statistic}_avg``,
    ``{statistic}_min`` and ``{statistic}_max``) on every `dist_aggregate_batch_interval`-th batch.

    Args:
        memory_keys (dict[str, str], optional): A dict specifying memory statistics to log. Keys
            are the names of memory statistics to log from `torch.cuda.memory_stats()`, and values
            are the names they will be logged under. If not provided, the above statistics are
            logged. Defaults to None.
        dist_aggregate_batch_interval (int, optional): interval, in training batches, for aggregating
            memory stats across all ranks. Defaults to None (by default the functionality is disabled).
    """

    def __init__(
        self,
        memory_keys: Optional[dict[str, str]] = None,
        dist_aggregate_batch_interval: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.memory_keys = memory_keys
        self.dist_aggregate_batch_interval = dist_aggregate_batch_interval

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: pl.utilities.types.STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """
        Collect the CUDA memory statistics and log them under ``memory/{statistic}``.

        When ``dist_aggregate_batch_interval`` is set, the cross-rank ``avg``/``min``/``max`` of
        every statistic is additionally logged on each ``dist_aggregate_batch_interval``-th batch.

        Args:
            trainer (pl.Trainer): The trainer running the loop (unused).
            pl_module (pl.LightningModule): The module being trained (unused).
            outputs (STEP_OUTPUT): Outputs of the training step (unused).
            batch (Any): The current batch (unused).
            batch_idx (int): Zero-based index of the batch that just finished.
        """
        memory_report = _get_memory_report(self.memory_keys)

        interval = self.dist_aggregate_batch_interval
        # Each aggregated statistic costs three all-reduces, so only pay for them on the
        # configured interval. ``batch_idx`` is zero-based, hence the +1: with an interval
        # of N the aggregation runs after the N-th, 2N-th, ... batch.
        if interval and (batch_idx + 1) % interval == 0:
            dist_memory_report = {}
            for mem_stat, val in memory_report.items():
                dist_memory_report[mem_stat + '_avg'] = reduce_value(val, 'avg')
                dist_memory_report[mem_stat + '_min'] = reduce_value(val, 'min')
                dist_memory_report[mem_stat + '_max'] = reduce_value(val, 'max')
            memory_report.update(dist_memory_report)

        for mem_stat, val in memory_report.items():
            self.log(f'memory/{mem_stat}', val)
# Default statistics reported by `_get_memory_report` when no custom mapping is given.
# Keys are looked up in the dict returned by `torch.cuda.memory_stats()`; values are the
# names the statistics are reported under. Entries whose key contains 'bytes' are
# converted to gigabytes by `_get_memory_report`; the others are reported as-is.
_MEMORY_KEYS = {
    'allocated_bytes.all.current': 'current_allocated_mem',
    'active_bytes.all.current': 'current_active_mem',
    'inactive_split_bytes.all.current': 'current_inactive_mem',
    'reserved_bytes.all.current': 'current_reserved_mem',
    'allocated_bytes.all.peak': 'peak_allocated_mem',
    'active_bytes.all.peak': 'peak_active_mem',
    'inactive_split_bytes.all.peak': 'peak_inactive_mem',
    'reserved_bytes.all.peak': 'peak_reserved_mem',
    'num_alloc_retries': 'alloc_retries',
}
def _get_memory_report(memory_keys: Optional[dict[str, str]] = None) -> dict[str, Union[int, float]]:
"""
Returns a dictionary with memory metrics.
Args:
memory_keys (Optional[dict[str, str]]): a dict specifying memory statistics to log.
Retuns:
dict: memory statistics.
"""
memory_stats = torch.cuda.memory_stats()
memory_keys = memory_keys or _MEMORY_KEYS
# simplify and reformat the memory_stats
memory_report = {}
for torch_name, name in memory_keys.items():
if torch_name in memory_stats:
# Convert to gigabytes
if 'bytes' in torch_name:
gigabytes = memory_stats[torch_name] / 1.0e9
# Round to preserve 5 significant digits
if gigabytes != 0:
order_of_magnitude = int(math.floor(math.log10(abs(gigabytes))))
gigabytes = round(gigabytes, -order_of_magnitude + 4)
memory_report[name.replace('bytes', 'gigabytes')] = gigabytes
else:
memory_report[name] = memory_stats[torch_name]
return memory_report
|