|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
from typing import Any, Optional, Union |
|
|
|
|
|
import lightning.pytorch as pl |
|
|
import torch |
|
|
import torch.cuda |
|
|
from torch import distributed |
|
|
|
|
|
|
|
|
def reduce_value(
    value: Union[int, float],
    reduce_op: str = 'mean',
    model_device: Optional[torch.device] = None,
):
    """
    Reduce a value across distributed processes.

    Args:
        value (Union[int, float]): The value to reduce.
        reduce_op (str, optional): The reduction operation to perform. One of 'mean', 'avg', 'sum', 'min', 'max'.
            Defaults to 'mean'.
        model_device (torch.device, optional): Device on which to allocate the reduction tensor.
            Backends such as NCCL can only reduce device (CUDA) tensors, so pass the model's
            device there. Defaults to None (CPU tensor), which preserves the previous behavior.

    Returns:
        Union[int, float]: The reduced value as a Python scalar.

    Raises:
        ValueError: If ``reduce_op`` is not one of the supported operations.
    """

    tensor_value = torch.tensor(value, device=model_device)

    # 'mean'/'avg' are implemented as a SUM all-reduce followed by a division by the
    # world size, since ReduceOp.AVG is not supported by every backend.
    if reduce_op in ['mean', 'avg', 'sum']:
        op = distributed.ReduceOp.SUM
    elif reduce_op == 'min':
        op = distributed.ReduceOp.MIN
    elif reduce_op == 'max':
        op = distributed.ReduceOp.MAX
    else:
        raise ValueError(f'{reduce_op=} not supported.')

    distributed.all_reduce(tensor_value, op=op)
    if reduce_op in ['mean', 'avg']:
        tensor_value = tensor_value / distributed.get_world_size()

    return tensor_value.item()
|
|
|
|
|
|
|
|
class MemoryMonitor(pl.Callback):
    """
    Logs the memory usage of the model.

    This callback calls the torch memory stats API for CUDA and reports different memory statistics.

    Example:
        import nemo_run as run
        from nemo.lightning.pytorch.callbacks import MemoryMonitor

        recipe.trainer.callbacks.append(
            run.Config(MemoryMonitor)
        )

    The memory statistics are logged by the :class:`.Logger` to the following keys as
    described below.

    +--------------------------+-------------------------------------------------------------+
    | Key                      | Logged data                                                 |
    +==========================+=============================================================+
    |                          | Several memory usage statistics                             |
    | ``memory/{statistic}``   | are logged on                                               |
    |                          | :attr:`.Event.AFTER_TRAIN_BATCH` event.                     |
    +--------------------------+-------------------------------------------------------------+

    The following statistics are recorded:

    +------------------------+----------------------------------------------------------------------------------------+
    | Statistic              | Description                                                                            |
    +========================+========================================================================================+
    | current_allocated_mem  | Current amount of allocated memory in gigabytes.                                       |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_active_mem     | Current amount of active memory in gigabytes at the time of recording.                 |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_inactive_mem   | Current amount of inactive, non-releaseable memory in gigabytes.                       |
    +------------------------+----------------------------------------------------------------------------------------+
    | current_reserved_mem   | Current amount of reserved memory in gigabytes at the time of recording.               |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_allocated_mem     | Peak amount of allocated memory in gigabytes.                                          |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_active_mem        | Peak amount of active memory in gigabytes at the time of recording.                    |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_inactive_mem      | Peak amount of inactive, non-releaseable memory in gigabytes at the time of recording. |
    +------------------------+----------------------------------------------------------------------------------------+
    | peak_reserved_mem      | Peak amount of reserved memory in gigabytes at the time of recording.                  |
    +------------------------+----------------------------------------------------------------------------------------+
    | alloc_retries          | Number of failed cudaMalloc calls that result in a cache flush and retry.              |
    +------------------------+----------------------------------------------------------------------------------------+

    Additionally, if `dist_aggregate_batch_interval` is enabled, the `avg`, `min`, and `max` of the
    aforementioned statistics are also logged.

    Args:
        memory_keys (dict[str, str], optional): A dict specifying memory statistics to log. Keys
            are the names of memory statistics to log from `torch.cuda.memory_stats()`, and values
            are the names they will be logged under. If not provided, the above statistics are
            logged. Defaults to None.
        dist_aggregate_batch_interval (int, optional): interval (in batches) for aggregating memory
            stats across all nodes. Defaults to None (by default the functionality is disabled).
    """

    def __init__(
        self,
        memory_keys: Optional[dict[str, str]] = None,
        dist_aggregate_batch_interval: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.memory_keys = memory_keys
        self.dist_aggregate_batch_interval = dist_aggregate_batch_interval

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: pl.utilities.types.STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Collect CUDA memory statistics after each training batch and log them under ``memory/*``."""
        memory_report = _get_memory_report(self.memory_keys)
        # Only aggregate across ranks every `dist_aggregate_batch_interval` batches; the
        # all-reduces are collective calls and would add overhead on every step otherwise.
        if self.dist_aggregate_batch_interval and batch_idx % self.dist_aggregate_batch_interval == 0:
            dist_memory_report = {}
            for mem_stat, val in memory_report.items():
                dist_memory_report[mem_stat + '_avg'] = reduce_value(val, 'avg')
                dist_memory_report[mem_stat + '_min'] = reduce_value(val, 'min')
                dist_memory_report[mem_stat + '_max'] = reduce_value(val, 'max')
            memory_report.update(dist_memory_report)

        memory_metrics = {f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()}
        for metric, value in memory_metrics.items():
            # pl.Callback has no `log` method; metrics must be logged through the LightningModule.
            pl_module.log(metric, value)
|
|
|
|
|
|
|
|
# Default mapping from statistic names reported by ``torch.cuda.memory_stats()``
# (keys) to the names they are logged under (values). Entries whose source name
# contains 'bytes' are converted to gigabytes by ``_get_memory_report``.
_MEMORY_KEYS = {
    'allocated_bytes.all.current': 'current_allocated_mem',
    'active_bytes.all.current': 'current_active_mem',
    'inactive_split_bytes.all.current': 'current_inactive_mem',
    'reserved_bytes.all.current': 'current_reserved_mem',
    'allocated_bytes.all.peak': 'peak_allocated_mem',
    'active_bytes.all.peak': 'peak_active_mem',
    'inactive_split_bytes.all.peak': 'peak_inactive_mem',
    'reserved_bytes.all.peak': 'peak_reserved_mem',
    'num_alloc_retries': 'alloc_retries',
}
|
|
|
|
|
|
|
|
def _get_memory_report(memory_keys: Optional[dict[str, str]] = None) -> dict[str, Union[int, float]]: |
|
|
""" |
|
|
Returns a dictionary with memory metrics. |
|
|
|
|
|
Args: |
|
|
memory_keys (Optional[dict[str, str]]): a dict specifying memory statistics to log. |
|
|
|
|
|
Retuns: |
|
|
dict: memory statistics. |
|
|
""" |
|
|
|
|
|
memory_stats = torch.cuda.memory_stats() |
|
|
memory_keys = memory_keys or _MEMORY_KEYS |
|
|
|
|
|
|
|
|
memory_report = {} |
|
|
for torch_name, name in memory_keys.items(): |
|
|
if torch_name in memory_stats: |
|
|
|
|
|
if 'bytes' in torch_name: |
|
|
gigabytes = memory_stats[torch_name] / 1.0e9 |
|
|
|
|
|
if gigabytes != 0: |
|
|
order_of_magnitude = int(math.floor(math.log10(abs(gigabytes)))) |
|
|
gigabytes = round(gigabytes, -order_of_magnitude + 4) |
|
|
memory_report[name.replace('bytes', 'gigabytes')] = gigabytes |
|
|
else: |
|
|
memory_report[name] = memory_stats[torch_name] |
|
|
|
|
|
return memory_report |
|
|
|