"""GPU memory management utilities."""
import gc
import warnings
from typing import Optional, Any
import torch
def clear_gpu_memory(model: Optional[Any] = None, tokenizer: Optional[Any] = None) -> None:
"""Clear GPU memory by emptying CUDA cache and running garbage collection.
This function performs aggressive GPU memory cleanup by:
1. Clearing CUDA cache
2. Running multiple garbage collection passes
Important: This function does NOT delete model or tokenizer objects.
The caller must set their references to None (e.g., `model = None`)
for the objects to be garbage collected and GPU memory to be freed.
.. deprecated::
The `model` and `tokenizer` parameters are deprecated and will be removed
in a future release. The function will become parameterless in the next major
version. These parameters are no longer used internally.
Args:
model: Optional model object (deprecated, will be removed in future release)
tokenizer: Optional tokenizer object (deprecated, will be removed in future release)
"""
    # Warn callers that still pass the deprecated parameters.
    if model is not None or tokenizer is not None:
        warnings.warn(
            "The 'model' and 'tokenizer' parameters to clear_gpu_memory() are deprecated "
            "and will be removed in a future release. The function will become "
            "parameterless in the next major version. These parameters are no longer "
            "used internally. Simply call clear_gpu_memory() without arguments.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Nothing to do on CPU-only machines.
    if not torch.cuda.is_available():
        return

    # Run garbage collection BEFORE emptying the cache, so that tensors kept
    # alive only by reference cycles are destroyed first and their blocks are
    # returned to the caching allocator, where empty_cache() can release them.
    # A single pass is sufficient with modern PyTorch and device_map="auto".
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
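

# The helper below is an illustrative sketch (an assumption, not part of this
# module's documented API): it shows one way to verify that calling
# clear_gpu_memory() with the deprecated parameters emits a DeprecationWarning.
# A real test suite would more likely express the same check with `pytest.warns`.
def _warns_on_deprecated_args() -> bool:
    """Return True if the deprecated call style emits a DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        clear_gpu_memory(model=object())  # deprecated call style
    return any(issubclass(w.category, DeprecationWarning) for w in caught)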
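

if __name__ == "__main__":
    # Minimal usage sketch of the documented call pattern (illustrative only):
    # drop your own references first, then call clear_gpu_memory(). A small
    # CUDA tensor stands in for a model here.
    if torch.cuda.is_available():
        tensor = torch.empty(1024, 1024, device="cuda")  # stand-in for a model
        print(f"allocated: {torch.cuda.memory_allocated()} bytes")
        tensor = None  # drop the only reference so the object is collectable
        clear_gpu_memory()
        print(f"after cleanup: {torch.cuda.memory_allocated()} bytes")
    else:
        print("CUDA is not available; nothing to demonstrate.")
    print(f"deprecation warning emitted: {_warns_on_deprecated_args()}")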