# Phillnet-2 / memory_optimization/tensor_pool.py
# Uploaded by ayjays132 ("Upload 478 files"), commit 101858b (verified)
"""
Tensor Pool Module
Unified tensor pooling system for memory efficiency.
"""
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
# Module-level logger; TensorPool emits debug messages from __init__ and clear_all.
logger = logging.getLogger(__name__)
class TensorPool:
    """
    Unified tensor pool for efficient memory management.

    Tensors are grouped by ``(shape, dtype, requires_grad)``. ``get_tensor``
    hands out a zero-filled tensor, reusing a pooled one that already lives on
    the requested device when possible; ``return_tensor`` gives a tensor back
    for later reuse. Callers must not keep using a tensor after returning it,
    since its storage may be handed to another caller.
    """

    def __init__(self, max_pool_size: int = 50, max_tensor_size: int = 1000000):
        # Maximum number of tensors retained per (shape, dtype, requires_grad) key.
        self.max_pool_size = max_pool_size
        # Tensors with more elements than this are never pooled.
        self.max_tensor_size = max_tensor_size
        # key -> list of detached tensors available for reuse.
        self.pools = defaultdict(list)
        # key -> number of get_tensor requests served (pool hits and new allocations alike).
        self.usage_stats = defaultdict(int)
        # Total number of get_tensor calls.
        self.operation_count = 0
        logger.debug("TensorPool initialized")

    @staticmethod
    def _resolve_device(device) -> torch.device:
        """Turn ``device`` (None, str or torch.device) into a concrete torch.device."""
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)
        # torch.device("cuda") != torch.device("cuda:0"); pin the index so it
        # compares equal to the .device of tensors allocated with it.
        if device.type == "cuda" and device.index is None:
            device = torch.device("cuda", torch.cuda.current_device())
        return device

    @staticmethod
    def _on_device(tensor: torch.Tensor, device: torch.device) -> bool:
        """Whether ``tensor`` lives on ``device`` (an index-less device matches any index of its type)."""
        if tensor.device == device:
            return True
        return device.index is None and tensor.device.type == device.type

    def get_tensor(self, shape: Tuple[int, ...], dtype: torch.dtype = torch.float32,
                   requires_grad: bool = False,
                   device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Get a zero-filled tensor from the pool, or create a new one.

        Args:
            shape: Tensor shape. An int or any sequence of ints (tuple, list,
                torch.Size) is accepted and normalized to a tuple of ints.
            dtype: Tensor data type
            requires_grad: Whether the returned tensor requires gradients
            device: Device for the tensor (torch.device or string). None means
                CUDA when available, otherwise CPU, for pooled and newly
                created tensors alike.
        Returns:
            A zero-filled leaf tensor with the requested shape, dtype, device
            and requires_grad, reused from the pool when one is available on
            that device, otherwise newly allocated.
        """
        self.operation_count += 1
        if isinstance(shape, int):
            shape = (shape,)
        # Normalize so lists / torch.Size are hashable and produce the same key
        # that return_tensor builds from tuple(tensor.shape).
        shape = tuple(int(dim) for dim in shape)
        key = (shape, dtype, requires_grad)
        target = self._resolve_device(device)

        pool = self.pools.get(key)  # .get() avoids creating empty defaultdict entries
        if pool:
            # Newest-first search for a tensor already on the target device, so
            # reuse never needs a cross-device copy.
            for idx in range(len(pool) - 1, -1, -1):
                if self._on_device(pool[idx], target):
                    tensor = pool.pop(idx)
                    tensor.zero_()  # Clear previous contents
                    # Pooled tensors are stored detached (requires_grad=False);
                    # restore the requested flag on the reused leaf.
                    if requires_grad:
                        tensor.requires_grad_(True)
                    self.usage_stats[key] += 1
                    return tensor

        # Create new tensor
        tensor = torch.zeros(shape, dtype=dtype, device=target, requires_grad=requires_grad)
        self.usage_stats[key] += 1
        return tensor

    def return_tensor(self, tensor: torch.Tensor) -> None:
        """
        Return a tensor to the pool for reuse.

        The tensor is silently not pooled when it is not a tensor, is not a
        dense (strided) tensor, has more than ``max_tensor_size`` elements, is
        non-contiguous, or its pool is already full. The caller must not use
        ``tensor`` after returning it.

        Args:
            tensor: Tensor to return to pool
        """
        if tensor is None or not isinstance(tensor, torch.Tensor):
            return
        # get_tensor only hands out dense tensors and the key has no layout.
        if tensor.layout != torch.strided:
            return
        # Don't pool very large tensors
        if tensor.numel() > self.max_tensor_size:
            return
        # Fresh tensors from get_tensor are contiguous; don't hand out strided
        # views (transposes, expands) in their place.
        if not tensor.is_contiguous():
            return
        key = (tuple(tensor.shape), tensor.dtype, tensor.requires_grad)
        pool = self.pools[key]
        # Only pool if we have space
        if len(pool) < self.max_pool_size:
            # detach() rather than detach_(): it also works on views and yields
            # a leaf without the caller's .grad or autograd history.
            pool.append(tensor.detach())

    def clear_pool(self, keep_ratio: float = 0.5) -> None:
        """
        Trim every pool down to ``int(max_pool_size * keep_ratio)`` tensors.

        The most recently returned tensors are dropped first. A ratio <= 0
        empties all pools; pools already at or below the limit are untouched.

        Args:
            keep_ratio: Ratio of max_pool_size to keep (0.0 to 1.0)
        """
        keep = max(0, int(self.max_pool_size * keep_ratio))
        for pool in self.pools.values():
            del pool[keep:]

    def clear_all(self) -> None:
        """Clear all tensor pools and usage statistics (operation_count is kept)."""
        self.pools.clear()
        self.usage_stats.clear()
        logger.debug("TensorPool cleared")

    def get_stats(self) -> Dict:
        """
        Get pool statistics.

        Returns:
            Dict with 'pools' (stringified key -> pooled tensor count),
            'usage_stats' (key -> get_tensor requests) and 'operation_count'.
        """
        return {
            'pools': {str(k): len(v) for k, v in self.pools.items()},
            'usage_stats': dict(self.usage_stats),
            'operation_count': self.operation_count
        }