Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from lightning.pytorch.callbacks.callback import Callback
from torch.utils.viz._cycles import warn_tensor_cycles
from nemo.lightning import io
from nemo.utils import logging
from nemo.utils.get_rank import get_rank
class MemoryProfileCallback(Callback, io.IOMixin):
"""
This callback enables recording a timeline of memory allocations during training.
The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz
More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/).
Args:
dir (Optional[str]): Directory to store the memory profile dump
warn_cycles (Optional[bool]): Whether to enable [reference cycle detection](https://pytorch.org/blog/understanding-gpu-memory-2/)
rank (Optional[list[int]]): List of ranks to collect snapshot on, defaults to all if list is empty
Example:
>>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0])
>>> trainer = Trainer(callbacks=[callback])
"""
def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]):
self.dir = dir
self.ranks = ranks
os.makedirs(self.dir, exist_ok=True)
logging.info(f"Torch memory profiles will be written to: {self.dir}")
if warn_cycles:
logging.info("Enabling reference cycle detector")
warn_tensor_cycles()
def enable_on_rank(self) -> bool:
if not self.ranks:
return True
return get_rank() in self.ranks
def setup(self, trainer, pl_module, stage) -> None:
"""PyTorch Lightning hook:
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
We use it here to start recording the memory profiler.
"""
if trainer.max_steps > 1000:
logging.warning(
f"Memory profiling creates snapshots during the entire training process, \
where every iteration increases the size of the snapshot. \
Try reducing trainer.max_steps to avoid running into issues"
)
if torch.distributed.is_initialized() and self.enable_on_rank():
torch.cuda.memory._record_memory_history(max_entries=100000)
def on_train_end(self, trainer, pl_module) -> None:
"""PyTorch Lightning hook:
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
We use it here to finish memory profiling and write the snapshot.
"""
logging.info(
f"on_train_batch_end rank: {get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}"
)
if torch.distributed.is_initialized() and self.enable_on_rank():
rank = get_rank()
_snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle"
logging.info(f"Writing memory profile snapshot to {_snapshot_path}")
torch.cuda.memory._dump_snapshot(f"{_snapshot_path}")
torch.cuda.memory._record_memory_history(enabled=None)
logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}")