File size: 3,580 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""

Tensor Pool Module

Unified tensor pooling system for memory efficiency.

"""
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch

logger = logging.getLogger(__name__)

class TensorPool:
    """
    Unified tensor pool for efficient memory management.

    Tensors are bucketed by ``(shape, dtype, requires_grad)``. ``get_tensor``
    hands out a zero-filled tensor, reusing a pooled one when the bucket is
    non-empty, and ``return_tensor`` puts a tensor back for later reuse.

    Pooled tensors are handed out without copying (unless a device move is
    needed), so a caller must not keep using a tensor after returning it.
    """

    def __init__(self, max_pool_size: int = 50, max_tensor_size: int = 1000000):
        """
        Args:
            max_pool_size: Maximum number of tensors kept per bucket.
            max_tensor_size: Tensors with more elements than this are never pooled.
        """
        self.max_pool_size = max_pool_size
        self.max_tensor_size = max_tensor_size
        # key -> detached tensors available for reuse (used as a LIFO stack)
        self.pools = defaultdict(list)
        # key -> number of get_tensor calls for that key (pool hits and misses)
        self.usage_stats = defaultdict(int)
        # total number of get_tensor calls since construction
        self.operation_count = 0

        logger.debug("TensorPool initialized")

    @staticmethod
    def _normalize_shape(shape: Union[int, Sequence[int]]) -> Tuple[int, ...]:
        """Return *shape* as a plain tuple of ints (hashable, and equal to
        ``tuple(tensor.shape)`` for a tensor of that shape)."""
        if isinstance(shape, int):
            return (shape,)
        return tuple(int(dim) for dim in shape)

    def get_tensor(self, shape: Union[int, Sequence[int]], dtype: torch.dtype = torch.float32,
                   requires_grad: bool = False,
                   device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Get a zero-filled tensor, reusing a pooled one when available.

        Args:
            shape: Tensor shape (int, tuple, list or ``torch.Size``).
            dtype: Tensor data type.
            requires_grad: Whether the returned tensor requires gradients.
            device: Device for the returned tensor. Defaults to CUDA when
                available, otherwise CPU; applied to pooled and new tensors alike.

        Returns:
            A zero-filled tensor with the requested shape, dtype, device and
            ``requires_grad`` flag. A reused tensor is a fresh autograd leaf
            with ``grad`` set to ``None``.
        """
        self.operation_count += 1
        shape = self._normalize_shape(shape)
        key = (shape, dtype, requires_grad)
        self.usage_stats[key] += 1

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        pool = self.pools.get(key)  # .get avoids creating empty buckets
        if pool:
            # Pooled tensors were detached on return (requires_grad=False), so
            # the device move and in-place zeroing are not tracked by autograd.
            # Grad is re-enabled only afterwards, keeping the result a leaf.
            tensor = pool.pop().to(device)
            tensor.zero_()
            tensor.requires_grad_(requires_grad)
            return tensor

        return torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)

    def return_tensor(self, tensor: torch.Tensor) -> None:
        """
        Return a tensor to the pool for reuse.

        Non-tensors (including ``None``), tensors larger than
        ``max_tensor_size`` elements, and tensors whose bucket is already full
        are ignored. The caller's tensor object itself is not modified.

        Args:
            tensor: Tensor to return to the pool.
        """
        if not isinstance(tensor, torch.Tensor):
            return

        # Don't pool very large tensors
        if tensor.numel() > self.max_tensor_size:
            return

        key = (tuple(tensor.shape), tensor.dtype, tensor.requires_grad)
        pool = self.pools[key]

        # Only pool if we have space
        if len(pool) < self.max_pool_size:
            # Out-of-place detach: unlike detach_() it also works on views,
            # and it yields a new leaf with no stale .grad. It shares storage
            # with the original tensor.
            pool.append(tensor.detach())

    def clear_pool(self, keep_ratio: float = 0.5) -> None:
        """
        Shrink every bucket to at most ``int(max_pool_size * keep_ratio)`` tensors.

        The oldest entries (front of each list) are kept. Empty buckets remain
        in ``pools``.

        Args:
            keep_ratio: Ratio of ``max_pool_size`` to keep (0.0 to 1.0).
        """
        keep = max(0, int(self.max_pool_size * keep_ratio))
        for pool in self.pools.values():
            del pool[keep:]

    def clear_all(self) -> None:
        """Clear all tensor pools and usage statistics (``operation_count`` is kept)."""
        self.pools.clear()
        self.usage_stats.clear()
        logger.debug("TensorPool cleared")

    def get_stats(self) -> Dict:
        """Get pool statistics: bucket sizes, per-key usage counts and total calls."""
        return {
            'pools': {str(k): len(v) for k, v in self.pools.items()},
            'usage_stats': dict(self.usage_stats),
            'operation_count': self.operation_count
        }