""" Tensor Pool Module Unified tensor pooling system for memory efficiency. """ import torch import logging from typing import Dict, Tuple, List from collections import defaultdict logger = logging.getLogger(__name__) class TensorPool: """ Unified tensor pool for efficient memory management. """ def __init__(self, max_pool_size: int = 50, max_tensor_size: int = 1000000): self.max_pool_size = max_pool_size self.max_tensor_size = max_tensor_size self.pools = defaultdict(list) self.usage_stats = defaultdict(int) self.operation_count = 0 logger.debug("TensorPool initialized") def get_tensor(self, shape: Tuple[int, ...], dtype: torch.dtype = torch.float32, requires_grad: bool = False, device: torch.device = None) -> torch.Tensor: """ Get tensor from pool or create new one. Args: shape: Tensor shape dtype: Tensor data type requires_grad: Whether tensor requires gradients device: Device to create tensor on Returns: Tensor from pool or newly created tensor """ self.operation_count += 1 key = (shape, dtype, requires_grad) # Try to get tensor from pool if key in self.pools and self.pools[key]: tensor = self.pools[key].pop() tensor.zero_() # Clear tensor self.usage_stats[key] += 1 return tensor.to(device) if device else tensor # Create new tensor if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tensor = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) self.usage_stats[key] += 1 return tensor def return_tensor(self, tensor: torch.Tensor) -> None: """ Return tensor to pool for reuse. Args: tensor: Tensor to return to pool """ if tensor is None or not isinstance(tensor, torch.Tensor): return # Don't pool very large tensors if tensor.numel() > self.max_tensor_size: return key = (tuple(tensor.shape), tensor.dtype, tensor.requires_grad) # Only pool if we have space if len(self.pools[key]) < self.max_pool_size: tensor.detach_() self.pools[key].append(tensor) def clear_pool(self, keep_ratio: float = 0.5) -> None: """ Clear tensor pool, keeping a percentage. Args: keep_ratio: Ratio of pool to keep (0.0 to 1.0) """ for key, pool in self.pools.items(): if len(pool) > self.max_pool_size * keep_ratio: excess = len(pool) - int(self.max_pool_size * keep_ratio) for _ in range(excess): if pool: pool.pop() def clear_all(self) -> None: """Clear all tensor pools.""" self.pools.clear() self.usage_stats.clear() logger.debug("TensorPool cleared") def get_stats(self) -> Dict: """Get pool statistics.""" return { 'pools': {str(k): len(v) for k, v in self.pools.items()}, 'usage_stats': dict(self.usage_stats), 'operation_count': self.operation_count }