# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import inspect
import logging
import os
from datetime import datetime
from pathlib import Path

import torch

from verl.utils.device import get_torch_device, is_cuda_available

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:
    """
    Aggressive GPU memory cleanup that tries to release memory PyTorch has
    reserved but not allocated.

    Args:
        force_sync: Whether to force device synchronization after clearing the cache.
        max_retries: Maximum number of cleanup attempts.
    """
    device = get_torch_device()
    if not device.is_available():
        return

    for attempt in range(max_retries):
        # Record memory status before cleanup
        before_reserved = device.memory_reserved()
        before_allocated = device.memory_allocated()

        # Run garbage collection
        gc.collect()

        # Clear PyTorch's cached blocks
        device.empty_cache()

        # Force synchronization (optional)
        if force_sync:
            device.synchronize()

        # Record memory status after cleanup
        after_reserved = device.memory_reserved()
        after_allocated = device.memory_allocated()

        # Calculate freed memory
        reserved_freed = before_reserved - after_reserved
        allocated_freed = before_allocated - after_allocated

        logger.info(
            f"Memory cleanup attempt {attempt + 1}: Freed {reserved_freed / 1024**3:.2f} GB reserved, "
            f"{allocated_freed / 1024**3:.2f} GB allocated"
        )

        # Stop retrying if little memory was freed
        if reserved_freed < 1024**3:  # less than 1 GB
            break


def reset_memory_stats() -> None:
    """Reset GPU memory statistics."""
    if get_torch_device().is_available():
        device = get_torch_device()
        device.reset_peak_memory_stats()
        device.reset_accumulated_memory_stats()


def get_memory_info() -> dict:
    """Get detailed GPU memory information."""
    if not get_torch_device().is_available():
        return {}

    device = get_torch_device()
    device_id = device.current_device()

    return {
        "total_memory_gb": device.get_device_properties(device_id).total_memory / 1024**3,
        "reserved_memory_gb": device.memory_reserved() / 1024**3,
        "allocated_memory_gb": device.memory_allocated() / 1024**3,
        "cached_memory_gb": (device.memory_reserved() - device.memory_allocated()) / 1024**3,
        "max_memory_allocated_gb": device.max_memory_allocated() / 1024**3,
        "max_memory_reserved_gb": device.max_memory_reserved() / 1024**3,
    }


def log_memory_usage(stage: str = "current") -> None:
    """Log GPU memory usage."""
    if not get_torch_device().is_available():
        return

    info = get_memory_info()
    logger.info(
        f"Memory usage [{stage}]: "
        f"Total: {info['total_memory_gb']:.2f} GB, "
        f"Allocated: {info['allocated_memory_gb']:.2f} GB, "
        f"Reserved: {info['reserved_memory_gb']:.2f} GB, "
        f"Cached: {info['cached_memory_gb']:.2f} GB"
    )
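

# Illustrative sketch, not part of the library API and not called anywhere:
# a typical cleanup cycle that pairs the helpers above. `run_one_step` is a
# hypothetical stand-in for any GPU-heavy callable.
def _example_cleanup_cycle(run_one_step) -> None:
    """Run one workload step, then log and reclaim GPU memory (example only)."""
    reset_memory_stats()
    log_memory_usage("before step")
    run_one_step()
    log_memory_usage("after step")
    aggressive_empty_cache(force_sync=True)
    log_memory_usage("after cleanup")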


def optimize_memory_for_inference() -> None:
    """Optimize GPU memory usage for inference."""
    if not get_torch_device().is_available():
        return

    # Set a more aggressive memory allocation policy
    get_torch_device().set_per_process_memory_fraction(0.95)  # Use 95% of GPU memory

    # Clear cache
    aggressive_empty_cache(force_sync=True)

    logger.info("Optimized GPU memory usage for inference")


def optimize_memory_for_training() -> None:
    """Optimize GPU memory usage for training."""
    if not get_torch_device().is_available():
        return

    # Set a moderate memory allocation policy
    get_torch_device().set_per_process_memory_fraction(0.9)  # Use 90% of GPU memory

    # Clear cache
    aggressive_empty_cache(force_sync=False)

    logger.info("Optimized GPU memory usage for training")


def enable_memory_visualize(
    trace_alloc_max_entries: int = 200_000,
    stack_depth: int = 32,
    context: str = "all",
    stacks: str = "all",
    devices=None,
    record_context: bool = True,
):
    """
    Enable memory history recording for CUDA allocations.

    This function should be called before any large-scale CUDA allocations.
    For DDP or multi-process setups, it must be called on each rank.

    Args:
        trace_alloc_max_entries (int): Maximum number of allocation entries to record.
        stack_depth (int): Depth of the call stack to capture for each allocation
            (supported by some PyTorch versions).
        context (str): The type of memory events to record. 'alloc' records only
            allocation events, 'state' records memory state changes, 'all' records both.
        stacks (str): The type of call stacks to record. 'python' records Python
            stacks, 'cpp' records C++ stacks (available in some versions), 'all'
            records both.
        devices (Union[int, list[int], None]): The device(s) for which to enable
            memory history. `None` enables it for the current default device.
        record_context (bool): Whether to record context information for
            allocations. Required by older PyTorch versions.
    """
    # Memory history recording is CUDA-specific functionality
    if not is_cuda_available:
        logger.warning("[memory_visualize] Memory history recording is only available on CUDA devices")
        return

    f = get_torch_device().memory._record_memory_history
    params = set(inspect.signature(f).parameters.keys())

    def _one_call(dev_kw=None):
        # Build kwargs against whichever signature this PyTorch version exposes
        kwargs = {}
        if "context" in params:
            kwargs["context"] = context
        if "stacks" in params:
            kwargs["stacks"] = stacks
        if "max_entries" in params:
            kwargs["max_entries"] = trace_alloc_max_entries
        elif "trace_alloc_max_entries" in params:
            kwargs["trace_alloc_max_entries"] = trace_alloc_max_entries
        if "stack_depth" in params:
            kwargs["stack_depth"] = stack_depth
        if dev_kw is not None:
            if "device" in params:
                kwargs["device"] = dev_kw
            elif "devices" in params:
                kwargs["devices"] = dev_kw if isinstance(dev_kw, list) else [dev_kw]
        if "record_context" in params:
            kwargs["record_context"] = record_context

        try:
            f(**kwargs)
            return "native", kwargs
        except TypeError:
            # Fall back to the legacy flag-based API
            if "trace_alloc_max_entries" in params and "record_context" in params:
                f(enabled=True, trace_alloc_max_entries=trace_alloc_max_entries, record_context=True)
                return "legacy", {
                    "enabled": True,
                    "trace_alloc_max_entries": trace_alloc_max_entries,
                    "record_context": True,
                }
            else:
                f(enabled=True)
                return "legacy-min", {"enabled": True}

    if devices is None or isinstance(devices, str | int | torch.device):
        mode, used = _one_call(devices)
    else:
        mode, used = "multi-device", {}
        for d in list(devices):
            _mode, _used = _one_call(d)
            used[f"dev{d}"] = _used

    device = get_torch_device()
    if device.is_available():
        device.reset_peak_memory_stats()
        device.synchronize()

    rank = int(os.environ.get("RANK", "0") or 0)
    logger.info(f"[memory_visualize][rank {rank}] recording enabled ({mode}); args={used}")
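

# Illustrative sketch, not part of the library API and not called anywhere:
# enable allocation-history recording at startup, run the workload, then dump a
# snapshot for the PyTorch memory visualizer. `run_workload` is a hypothetical
# stand-in for the actual training or inference loop.
def _example_trace_and_dump(run_workload) -> None:
    """Record allocation history around a workload and dump one snapshot (example only)."""
    enable_memory_visualize(trace_alloc_max_entries=100_000, stack_depth=16)
    run_workload()
    # MemorySnapshotSampler is defined below; the name is resolved at call time.
    MemorySnapshotSampler().dump_memory_snapshot(tag="after_workload")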


class MemorySnapshotSampler:
    """
    A utility class that dumps GPU memory snapshots. This is useful for monitoring
    memory usage over a long-running process. The dumped files can be visualized
    with https://docs.pytorch.org/memory_viz

    Args:
        out_dir (str): The directory where the snapshots will be saved.
        tag (str): A tag for the snapshot filenames.
    """

    def __init__(self, out_dir: str = "./mem_snapshots", tag: str = "periodic"):
        self.out_dir = out_dir
        self.tag = tag

    def dump_memory_snapshot(self, out_dir: str = "./mem_snapshots", tag: str = "snapshot", sub_dir: str | None = None):
        """
        Generate a memory snapshot and save it as a pickle file in the specified
        directory. The files are organized by timestamp in subdirectories, with all
        ranks' files placed in the same timestamp subdirectory.

        Args:
            out_dir (str): The directory where the snapshot file will be saved.
                The directory is created if it does not exist.
            tag (str): A string tag to prepend to the filename for easier identification.
            sub_dir (str): A subdirectory to place the snapshot file in.
        """
        if sub_dir is None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M")
            out_path = Path(out_dir) / timestamp
        else:
            out_path = Path(out_dir) / sub_dir
        out_path.mkdir(parents=True, exist_ok=True)

        # Get the rank of the current process
        rank = os.environ.get("RANK", "0")
        pid = os.getpid()
        # TODO(chenyang): check whether we need to sync all ranks before the dump
        fname = f"{tag}_rank{rank}_pid{pid}.pickle"
        path = out_path / fname

        device = get_torch_device()
        if not device.is_available():
            logger.warning("[memory_visualize] snapshot dumping is only available on CUDA devices.")
            return

        try:
            device.synchronize()
            # Memory snapshot is CUDA-specific functionality
            device.memory._dump_snapshot(str(path))
            logger.info(f"[memory_visualize] dumped: {path}")
        except Exception as e:
            logger.warning(f"[memory_visualize] dump failed: {e}")
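

# Minimal smoke test, assuming this module is run directly on a host with at
# least one visible CUDA GPU. It exercises only the helpers defined above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)
    if is_cuda_available:
        enable_memory_visualize()
        x = torch.randn(1024, 1024, device="cuda")  # small allocation to trace
        log_memory_usage("after allocation")
        del x
        aggressive_empty_cache()
        log_memory_usage("after cleanup")
        MemorySnapshotSampler().dump_memory_snapshot(tag="smoke")
    else:
        logger.warning("CUDA is not available; nothing to demonstrate.")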