File size: 1,879 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""

Memory Optimization Module

Unified memory management system with shared Qwen model integration.

"""
from .config import MemoryOptimizationConfig
from .manager import UnifiedMemoryManager
from .tensor_pool import TensorPool
from .model_cache import ModelCache
from .cleanup import MemoryCleanup

# Process-wide UnifiedMemoryManager, created lazily on first access.
_unified_memory_manager = None

def get_unified_memory_manager() -> UnifiedMemoryManager:
    """Return the module-level UnifiedMemoryManager, creating it on first use.

    The manager is constructed with no arguments the first time this is
    called and the same object is returned on every later call.

    NOTE(review): the check-then-create sequence is not guarded by a lock,
    so two threads making the very first call concurrently could each build
    a manager. Confirm whether callers can race here.
    """
    global _unified_memory_manager
    manager = _unified_memory_manager
    if manager is None:
        manager = UnifiedMemoryManager()
        _unified_memory_manager = manager
    return manager

# Convenience functions for backward compatibility
def get_shared_model(model_name: str, model_type: str = "transformer", **kwargs):
    """Delegate to the global manager's ``get_shared_model``.

    ``model_name`` and ``model_type`` are forwarded positionally. Any extra
    keyword arguments are passed through unchanged. Whatever the manager
    returns is returned as-is.
    """
    manager = get_unified_memory_manager()
    return manager.get_shared_model(model_name, model_type, **kwargs)

def get_tensor(shape, dtype=None, requires_grad: bool = False, module_name: str = "default"):
    """Fetch a tensor from the global manager's unified pool.

    All four arguments are forwarded positionally to the manager's
    ``get_tensor``. ``shape`` and ``dtype`` presumably follow torch
    conventions; this is an assumption to confirm against the manager.
    ``module_name`` tags which caller the tensor is requested for.
    """
    pool_owner = get_unified_memory_manager()
    tensor = pool_owner.get_tensor(shape, dtype, requires_grad, module_name)
    return tensor

def return_tensor(tensor, module_name: str = "default") -> None:
    """Hand ``tensor`` back to the global manager's unified pool.

    ``module_name`` is forwarded alongside the tensor. The manager's return
    value, if any, is discarded.
    """
    pool_owner = get_unified_memory_manager()
    pool_owner.return_tensor(tensor, module_name)

def clear_memory() -> None:
    """Ask the global manager to clear everything it holds.

    Calls ``clear_all_memory()`` on the manager and discards its result.
    """
    manager = get_unified_memory_manager()
    manager.clear_all_memory()

def get_memory_stats():
    """Return the statistics reported by the global manager.

    The value is whatever ``UnifiedMemoryManager.get_memory_stats()``
    produces, passed through unchanged.
    """
    manager = get_unified_memory_manager()
    stats = manager.get_memory_stats()
    return stats

# Public API exported by ``from <package> import *``.
__all__ = [
    # Classes re-exported from the submodules imported above.
    "MemoryOptimizationConfig",
    "UnifiedMemoryManager",
    "TensorPool",
    "ModelCache",
    "MemoryCleanup",
    # Module-level accessor and convenience wrappers defined in this file.
    "get_unified_memory_manager",
    "get_shared_model",
    "get_tensor",
    "return_tensor",
    "clear_memory",
    "get_memory_stats",
]