File size: 5,479 Bytes
a0d6949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gc
import torch
from typing import Optional


class ResourceManager:
    """
    Centralized GPU resource management for model lifecycle.
    Ensures efficient memory usage and prevents OOM errors.

    Tracks registered model instances by name and provides cache-clearing,
    memory-stat, and model-switching helpers around ``torch.cuda``.
    """

    def __init__(self):
        """Initialize resource manager and detect device."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Name of the most recently registered model (or None).
        self.current_model = None
        # Maps model_name -> model instance for lifecycle tracking.
        self.model_registry = {}

    def get_device(self) -> str:
        """
        Get current compute device.

        Returns:
            str: 'cuda' or 'cpu'
        """
        return self.device

    def register_model(self, model_name: str, model_instance: object) -> None:
        """
        Register a model instance for tracking.

        Args:
            model_name (str): Identifier for the model
            model_instance (object): The model object
        """
        self.model_registry[model_name] = model_instance
        self.current_model = model_name
        print(f"✓ Model registered: {model_name}")

    def unregister_model(self, model_name: str) -> None:
        """
        Unregister and cleanup a specific model.

        Args:
            model_name (str): Identifier of model to remove
        """
        if model_name in self.model_registry:
            del self.model_registry[model_name]
            if self.current_model == model_name:
                self.current_model = None
            print(f"✓ Model unregistered: {model_name}")

    def clear_cache(self, aggressive: bool = False) -> None:
        """
        Clear CUDA cache and run garbage collection.

        Args:
            aggressive (bool): If True, performs additional cleanup steps
        """
        if self.device == "cuda":
            if aggressive:
                # FIX: collect Python garbage BEFORE flushing the cache so
                # that tensors kept alive only by collectable reference
                # cycles are released and their blocks can actually be
                # returned by empty_cache().
                gc.collect()

            # Standard cache clearing
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

            if aggressive:
                # Aggressive cleanup for critical memory situations:
                # a second flush plus a synchronize to ensure all pending
                # kernels have finished before we report the cache cleared.
                with torch.cuda.device(self.device):
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()

            print(f"✓ CUDA cache cleared (aggressive={aggressive})")
        else:
            gc.collect()
            print("✓ CPU memory garbage collected")

    def cleanup_model(self, model_instance: Optional[object] = None) -> None:
        """
        Safely cleanup a model instance and free GPU memory.

        Args:
            model_instance (Optional[object]): Specific model to cleanup.
                                               If None, cleans all registered models.
        """
        if model_instance is not None:
            # Cleanup specific model: move weights off the GPU first.
            if hasattr(model_instance, 'to'):
                model_instance.to('cpu')
            # BUG FIX: the original only did `del model_instance`, which
            # unbinds the *local* name and frees nothing — the registry
            # still held a reference. Drop any registry entries that point
            # at this instance so it can actually be garbage collected.
            for name, model in list(self.model_registry.items()):
                if model is model_instance:
                    self.unregister_model(name)
            del model_instance
        else:
            # Cleanup all registered models.
            for name, model in list(self.model_registry.items()):
                if hasattr(model, 'to'):
                    model.to('cpu')
                del model
                self.unregister_model(name)

        # Force cleanup of anything the loops released.
        gc.collect()
        self.clear_cache(aggressive=True)
        print("✓ Model cleanup completed")

    def get_memory_stats(self) -> dict:
        """
        Get current GPU memory statistics.

        Note: assumes CUDA device 0; ``free_gb`` is intra-process
        (total - allocated), i.e. memory reserved-but-cached by the
        allocator is counted as reusable.

        Returns:
            dict: Memory statistics (allocated, reserved, free)
        """
        if self.device == "cuda" and torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            reserved = torch.cuda.memory_reserved() / 1024**3    # GB
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3

            return {
                "allocated_gb": round(allocated, 2),
                "reserved_gb": round(reserved, 2),
                "total_gb": round(total, 2),
                "free_gb": round(total - allocated, 2)
            }
        else:
            # CPU fallback: no GPU stats to report.
            return {
                "allocated_gb": 0,
                "reserved_gb": 0,
                "total_gb": 0,
                "free_gb": 0
            }

    def ensure_memory_available(self, required_gb: float = 2.0) -> bool:
        """
        Check if sufficient GPU memory is available.

        Attempts an aggressive cache clear once before giving up.

        Args:
            required_gb (float): Required memory in GB

        Returns:
            bool: True if memory is available, False otherwise
        """
        stats = self.get_memory_stats()
        available = stats["free_gb"]

        if available < required_gb:
            print(f"⚠ Low memory: {available:.2f}GB available, {required_gb:.2f}GB required")
            # Attempt cleanup, then re-read the stats.
            self.clear_cache(aggressive=True)
            stats = self.get_memory_stats()
            available = stats["free_gb"]

        return available >= required_gb

    def switch_model_context(self, from_model: str, to_model: str) -> None:
        """
        Handle model switching with proper cleanup.

        Args:
            from_model (str): Current model to unload
            to_model (str): Next model to prepare for
        """
        print(f"→ Switching context: {from_model}{to_model}")

        # Unregister old model
        self.unregister_model(from_model)

        # Aggressive cleanup before loading new model
        self.clear_cache(aggressive=True)

        # Update current model tracker (the new model is expected to be
        # register_model()'d by the caller after loading).
        self.current_model = to_model
        print(f"✓ Context switched to {to_model}")