File size: 4,695 Bytes
3ff2f18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""GPU resource management utilities for ZeroGPU compatibility.

This module provides utilities for managing GPU resources, including model device
transfers, cache management, and context managers for automatic cleanup.
"""

import logging
import time
from contextlib import contextmanager
from typing import Any, Optional

import torch

from src.config.gpu_config import GPUConfig

logger = logging.getLogger(__name__)


def acquire_gpu(model: torch.nn.Module, device: str = "cuda") -> bool:
    """Move a model to the specified GPU device.

    Args:
        model: PyTorch model to move to GPU
        device: Target device (default: "cuda")

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # perf_counter() is monotonic; time.time() can jump backwards on
        # wall-clock adjustments and yield a negative/garbage duration.
        start_time = time.perf_counter()
        target_device = torch.device(device)
        model.to(target_device)
        elapsed = time.perf_counter() - start_time

        # Lazy %-style args: formatting is skipped entirely when DEBUG is off.
        logger.debug(
            "Model %s moved to %s in %.3fs",
            model.__class__.__name__,
            device,
            elapsed,
        )
        return True
    except Exception as e:
        # Broad catch is deliberate: callers only need a success flag, and
        # failures (bad device string, OOM) must not propagate.
        logger.error("Failed to move model to %s: %s", device, e)
        return False


def release_gpu(model: torch.nn.Module, clear_cache: bool = True) -> bool:
    """Move a model back to CPU and optionally clear CUDA cache.

    Args:
        model: PyTorch model to move to CPU
        clear_cache: Whether to clear CUDA cache after moving

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # perf_counter() is monotonic, unlike time.time(), so the timeout
        # comparison below cannot be skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        model.to(torch.device("cpu"))

        # Cache clearing is gated by both the caller and project config so
        # deployments can disable the (potentially slow) empty_cache call.
        if clear_cache and GPUConfig.ENABLE_CACHE_CLEARING and torch.cuda.is_available():
            torch.cuda.empty_cache()

        elapsed = time.perf_counter() - start_time

        if elapsed > GPUConfig.CLEANUP_TIMEOUT:
            # Lazy %-args: no string work unless the record is emitted.
            logger.warning(
                "GPU cleanup took %.3fs, exceeding %ss limit",
                elapsed,
                GPUConfig.CLEANUP_TIMEOUT,
            )
        else:
            logger.debug("GPU released in %.3fs", elapsed)

        return True
    except Exception as e:
        # Broad catch is deliberate: release is best-effort cleanup.
        logger.error("Failed to release GPU: %s", e)
        return False


@contextmanager
def gpu_context(model: torch.nn.Module, device: str = "cuda"):
    """Context manager for automatic GPU resource management.

    Acquires GPU on entry and releases it on exit, even if an exception occurs.
    If acquisition fails, a warning is logged and the model is yielded on
    whatever device it currently occupies.

    Args:
        model: PyTorch model to manage
        device: Target GPU device (default: "cuda")

    Yields:
        torch.nn.Module: The model on the GPU device

    Example:
        >>> with gpu_context(my_model) as model:
        ...     result = model(input_data)
    """
    acquired = False
    try:
        acquired = acquire_gpu(model, device)
        if not acquired:
            # BUG FIX: torch.nn.Module has no `.device` attribute, so the old
            # f"...{model.device}" raised AttributeError on this path. Derive
            # the device from the first parameter instead.
            try:
                current_device = next(model.parameters()).device
            except StopIteration:
                current_device = "unknown (no parameters)"
            logger.warning("Failed to acquire GPU, model remains on %s", current_device)
        yield model
    finally:
        # Only release what we actually acquired; a failed acquire means the
        # model never left its original device.
        if acquired:
            release_gpu(model, clear_cache=True)


def move_to_device(data: Any, device: torch.device) -> Any:
    """Recursively move tensors to the specified device.

    Containers (dict, list, tuple) are rebuilt with every element transferred;
    values that are neither tensors nor containers pass through unchanged.

    Args:
        data: Data to move (tensor, list, tuple, dict, or other)
        device: Target device

    Returns:
        Data with all tensors moved to the device
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        moved = {}
        for key, value in data.items():
            moved[key] = move_to_device(value, device)
        return moved
    if isinstance(data, (list, tuple)):
        elements = [move_to_device(element, device) for element in data]
        return elements if isinstance(data, list) else tuple(elements)
    # Scalars, strings, None, etc. need no device transfer.
    return data


def get_gpu_memory_info() -> Optional[dict]:
    """Get current GPU memory usage information.

    Returns:
        dict: Memory information with 'allocated' and 'reserved' in GB, or None if CUDA unavailable
    """
    if not torch.cuda.is_available():
        return None

    bytes_per_gb = 1024**3
    try:
        return {
            "allocated_gb": round(torch.cuda.memory_allocated() / bytes_per_gb, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / bytes_per_gb, 2),
        }
    except Exception as e:
        logger.error(f"Failed to get GPU memory info: {e}")
        return None


def log_gpu_usage(operation: str):
    """Log current GPU memory usage for a specific operation.

    Args:
        operation: Description of the operation being performed
    """
    stats = get_gpu_memory_info()
    if not stats:
        # CUDA unavailable (or stats lookup failed): nothing to report.
        return
    logger.info(
        f"[{operation}] GPU Memory - Allocated: {stats['allocated_gb']}GB, "
        f"Reserved: {stats['reserved_gb']}GB"
    )