""" Buffer manager for efficient GPU memory management. This module provides a buffer pool manager to avoid repeated GPU memory allocations during distributed communication operations. """ import torch from typing import Dict, List, Tuple, Optional import threading import logging from .data_containers import CommunicationConfig BufferKey = Tuple[Tuple[int, ...], torch.dtype] class BufferManager: """ Manages GPU buffer pools to avoid repeated allocations. This class maintains pools of pre-allocated GPU tensors that can be reused across communication operations, reducing memory allocation overhead. """ def __init__(self, device: torch.device, config: Optional[CommunicationConfig] = None): """ Initialize the buffer manager. Args: device: GPU device for buffer allocation config: Communication configuration """ self.device = device self.config = config or CommunicationConfig() # Buffer pools: {(shape, dtype): [tensor1, tensor2, ...]} self.free_buffers: Dict[BufferKey, List[torch.Tensor]] = {} self.free_buffers_origin: Dict[BufferKey, List[torch.Tensor]] = {} self.free_buffers_kv: Dict[BufferKey, List[torch.Tensor]] = {} self.free_buffers_misc: Dict[BufferKey, List[torch.Tensor]] = {} # Thread safety self._lock = threading.Lock() # Statistics self.allocation_count = 0 self.reuse_count = 0 self.total_allocated_memory = 0 # Setup logging self.logger = logging.getLogger(f"BufferManager_{device}") self.logger.propagate = False if not self.logger.handlers: handler = logging.StreamHandler() # handler.setLevel(logging.DEBUG) formatter = logging.Formatter( f'[BufferManager {device}] %(asctime)s - %(name)s - %(levelname)s - %(message)s' ) handler.setFormatter(formatter) self.logger.addHandler(handler) # self.logger.setLevel(logging.DEBUG) def get_buffer(self, shape: Tuple[int, ...], dtype: torch.dtype, buffer_type: str = "latent") -> torch.Tensor: """ Get or allocate a buffer with the specified shape and dtype. Args: shape: Tensor shape dtype: Tensor data type buffer_type: Type of buffer ("latent", "origin", "kv") Returns: Tensor buffer """ with self._lock: # Select the appropriate buffer pool if buffer_type == "latent": buffer_pool = self.free_buffers elif buffer_type == "origin": buffer_pool = self.free_buffers_origin elif buffer_type == "kv": buffer_pool = self.free_buffers_kv elif buffer_type == "misc": buffer_pool = self.free_buffers_misc else: raise ValueError(f"Unknown buffer type: {buffer_type}") # Try to reuse existing buffer key = (tuple(shape), dtype) if self.config.enable_buffer_reuse and key in buffer_pool and len(buffer_pool[key]) > 0: buffer = buffer_pool[key].pop() self.reuse_count += 1 self.logger.debug(f"Reused buffer of shape {shape}, dtype {dtype}, type {buffer_type}") return buffer # Allocate new buffer buffer = torch.empty(shape, dtype=dtype, device=self.device) self.allocation_count += 1 self.total_allocated_memory += buffer.numel() * buffer.element_size() self.logger.debug(f"Allocated new buffer of shape {shape}, dtype {dtype}, type {buffer_type}") return buffer def return_buffer(self, tensor: torch.Tensor, buffer_type: str = "latent") -> None: """ Return a buffer to the pool for reuse. Args: tensor: Tensor to return buffer_type: Type of buffer ("latent", "origin", "kv") """ if not self.config.enable_buffer_reuse: return with self._lock: # Select the appropriate buffer pool if buffer_type == "latent": buffer_pool = self.free_buffers elif buffer_type == "origin": buffer_pool = self.free_buffers_origin elif buffer_type == "kv": buffer_pool = self.free_buffers_kv elif buffer_type == "misc": buffer_pool = self.free_buffers_misc else: raise ValueError(f"Unknown buffer type: {buffer_type}") key = (tuple(tensor.shape), tensor.dtype) # Initialize pool for this shape if it doesn't exist if key not in buffer_pool: buffer_pool[key] = [] # Add buffer to pool if not at capacity if len(buffer_pool[key]) < self.config.buffer_pool_size: buffer_pool[key].append(tensor) self.logger.debug( f"Returned buffer of shape {tuple(tensor.shape)}, dtype {tensor.dtype}, type {buffer_type}" ) else: self.logger.debug( f"Buffer pool full for shape {tuple(tensor.shape)}, dtype {tensor.dtype}, type {buffer_type}, discarding" ) def clear_buffers(self, buffer_type: Optional[str] = None) -> None: """ Clear buffer pools to free memory. Args: buffer_type: Specific buffer type to clear, or None to clear all """ with self._lock: if buffer_type is None: # Clear all buffer pools self.free_buffers.clear() self.free_buffers_origin.clear() self.free_buffers_kv.clear() self.free_buffers_misc.clear() self.logger.info("Cleared all buffer pools") else: # Clear specific buffer pool if buffer_type == "latent": self.free_buffers.clear() elif buffer_type == "origin": self.free_buffers_origin.clear() elif buffer_type == "kv": self.free_buffers_kv.clear() elif buffer_type == "misc": self.free_buffers_misc.clear() else: raise ValueError(f"Unknown buffer type: {buffer_type}") self.logger.info(f"Cleared {buffer_type} buffer pool") def get_statistics(self) -> Dict[str, any]: """ Get buffer manager statistics. Returns: Dictionary containing statistics """ with self._lock: total_free_buffers = sum(len(pool) for pool in self.free_buffers.values()) total_free_buffers_origin = sum(len(pool) for pool in self.free_buffers_origin.values()) total_free_buffers_kv = sum(len(pool) for pool in self.free_buffers_kv.values()) total_free_buffers_misc = sum(len(pool) for pool in self.free_buffers_misc.values()) return { "allocation_count": self.allocation_count, "reuse_count": self.reuse_count, "total_allocated_memory_bytes": self.total_allocated_memory, "total_free_buffers": total_free_buffers, "total_free_buffers_origin": total_free_buffers_origin, "total_free_buffers_kv": total_free_buffers_kv, "total_free_buffers_misc": total_free_buffers_misc, "reuse_rate": self.reuse_count / max(1, self.allocation_count), "buffer_pool_size": self.config.buffer_pool_size, "enable_buffer_reuse": self.config.enable_buffer_reuse } def print_statistics(self) -> None: """Print buffer manager statistics.""" stats = self.get_statistics() self.logger.info("Buffer Manager Statistics:") for key, value in stats.items(): self.logger.info(f" {key}: {value}") def preallocate_buffers(self, common_shapes: List[Tuple[Tuple[int, ...], torch.dtype, str]], count_per_shape: int = 5) -> None: """ Preallocate buffers for common shapes to reduce allocation overhead. Args: common_shapes: List of (shape, dtype, buffer_type) tuples count_per_shape: Number of buffers to preallocate per shape """ with self._lock: for shape, dtype, buffer_type in common_shapes: for _ in range(count_per_shape): buffer = torch.empty(shape, dtype=dtype, device=self.device) # Select the appropriate buffer pool if buffer_type == "latent": buffer_pool = self.free_buffers elif buffer_type == "origin": buffer_pool = self.free_buffers_origin elif buffer_type == "kv": buffer_pool = self.free_buffers_kv elif buffer_type == "misc": buffer_pool = self.free_buffers_misc else: raise ValueError(f"Unknown buffer type: {buffer_type}") # Initialize pool for this shape if it doesn't exist key = (tuple(shape), dtype) if key not in buffer_pool: buffer_pool[key] = [] buffer_pool[key].append(buffer) self.allocation_count += 1 self.total_allocated_memory += buffer.numel() * buffer.element_size() self.logger.info(f"Preallocated {len(common_shapes) * count_per_shape} buffers") def __del__(self): """Cleanup when the buffer manager is destroyed.""" try: self.clear_buffers() except Exception: pass