File size: 1,915 Bytes
16c2a22
 
 
92bb437
1e23279
 
16c2a22
 
 
1e23279
9db586c
1e23279
 
9db586c
 
 
 
 
 
 
92bb437
 
 
 
1e23279
 
92bb437
 
1e23279
92bb437
 
 
 
 
 
 
 
 
 
 
16c2a22
 
 
dc14519
 
16c2a22
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""GPU memory management utilities."""

import gc
import warnings
from typing import Optional, Any

import torch


def clear_gpu_memory(model: Optional[Any] = None, tokenizer: Optional[Any] = None) -> None:
    """Free GPU memory by collecting garbage and emptying the CUDA cache.

    Cleanup order matters: garbage collection runs first so that unreachable
    Python objects holding CUDA tensors are destroyed, and only then is the
    CUDA caching allocator asked to release its now-unused blocks back to the
    driver. (Emptying the cache *before* collecting would leave memory owned
    by unreachable tensors still parked in the cache, defeating the purpose
    of the call.)

    Important: This function does NOT delete model or tokenizer objects.
    The caller must set their references to None (e.g., ``model = None``)
    for the objects to be garbage collected and GPU memory to be freed.

    .. deprecated::
        The `model` and `tokenizer` parameters are deprecated and will be
        removed in a future release. The function will become parameterless
        in the next major version. These parameters are no longer used
        internally.

    Args:
        model: Optional model object (deprecated, unused; will be removed).
        tokenizer: Optional tokenizer object (deprecated, unused; will be removed).
    """
    # Emit deprecation warning if the legacy parameters are provided.
    if model is not None or tokenizer is not None:
        warnings.warn(
            "The 'model' and 'tokenizer' parameters to clear_gpu_memory() are deprecated "
            "and will be removed in a future release. The function will become parameterless "
            "in the next major version. These parameters are no longer used internally. "
            "Simply call clear_gpu_memory() without arguments.",
            DeprecationWarning,
            stacklevel=2
        )

    # Collect unreachable Python objects first: this drops the last references
    # to any CUDA tensors so their storage returns to the caching allocator
    # and can actually be released by empty_cache() below. Running GC is
    # useful (host-side) even when CUDA is unavailable.
    gc.collect()

    if not torch.cuda.is_available():
        return

    # Wait for in-flight kernels to finish, then hand cached-but-unused
    # blocks back to the driver so other processes (and tools like
    # nvidia-smi) see the memory as free.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()