"""
Model data transfer abstraction layer.

This module provides high-level interfaces for transferring model data between
distributed ranks during inference.
"""

import logging
import time
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Tuple

import torch

from .distributed_communicator import DistributedCommunicator
from .buffer_manager import BufferManager
from .kv_cache_manager import KVCacheManager
from .data_containers import LatentData, CommunicationConfig, PerformanceMetrics
from .utils import CommunicationTimer


class ModelDataTransfer:
    """
    High-level interface for model data transfer operations.

    This class encapsulates all model-related data transfer operations,
    providing a clean interface for sending and receiving latent data, KV
    caches, and other model state between ranks.

    The transfer statistics (``transfer_count`` and ``total_transfer_time``)
    cover the latent-data send/receive calls only; prompt, KV-cache, broadcast,
    gather and barrier operations are timed/logged via ``CommunicationTimer``
    but are not folded into these counters.
    """

    def __init__(self,
                 communicator: DistributedCommunicator,
                 buffer_manager: BufferManager,
                 kv_cache_manager: Optional[KVCacheManager] = None,
                 config: Optional[CommunicationConfig] = None):
        """
        Initialize the model data transfer manager.

        Args:
            communicator: Distributed communicator instance (must expose ``rank``).
            buffer_manager: Buffer manager for tensor allocation.
            kv_cache_manager: KV cache manager (optional; required by the
                KV-cache methods, which raise ``RuntimeError`` without it).
            config: Communication configuration; a default
                ``CommunicationConfig()`` is created when omitted or falsy.
        """
        self.comm = communicator
        self.buffer_mgr = buffer_manager
        self.kv_cache_mgr = kv_cache_manager
        self.config = config or CommunicationConfig()

        # One logger per rank. Handlers are attached only once so that creating
        # several instances on the same rank does not duplicate log lines, and
        # propagation is disabled so records are not re-emitted by the root
        # logger. No level is set here, so the effective level is inherited
        # from ancestor loggers (root by default).
        self.logger = logging.getLogger(f"ModelDataTransfer_rank_{communicator.rank}")
        self.logger.propagate = False
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                f'[Rank {communicator.rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Performance tracking (latent-data transfers only).
        self.transfer_count = 0
        self.total_transfer_time = 0.0

    @contextmanager
    def _timed_transfer(self, label: str) -> Iterator[None]:
        """
        Time one latent-data transfer and record it in the running statistics.

        The wrapped body runs inside ``CommunicationTimer(label, self.logger)``.
        On normal exit the elapsed wall-clock time is added to
        ``total_transfer_time`` and ``transfer_count`` is incremented; if the
        body raises, neither counter changes and the exception propagates.

        For the ``*_async`` communicator calls the measured time is the time
        spent inside the call, which may not include completion of the
        underlying communication — TODO confirm against the communicator.
        """
        start = time.perf_counter()
        with CommunicationTimer(label, self.logger):
            yield
        self.total_transfer_time += time.perf_counter() - start
        self.transfer_count += 1

    def _average_transfer_time(self) -> float:
        """Return mean seconds per recorded latent transfer (0.0 when none)."""
        return self.total_transfer_time / max(1, self.transfer_count)

    def send_latent_data_async(self,
                               chunk_idx: int,
                               latents: torch.Tensor,
                               original_latents: torch.Tensor,
                               patched_x_shape: torch.Tensor,
                               current_start: torch.Tensor,
                               current_end: torch.Tensor,
                               current_step: int) -> List[Any]:
        """
        Asynchronously send latent data to the next rank.

        Args:
            chunk_idx: Chunk index
            latents: Latent tensor
            original_latents: Original latent tensor
            patched_x_shape: Patched x shape tensor
            current_start: Current start indices
            current_end: Current end indices
            current_step: Current step

        Returns:
            List of work objects for all send operations, as returned by the
            communicator's ``send_latent_data_async``.
        """
        with self._timed_transfer(f"send_latent_data_async chunk_{chunk_idx}"):
            work_objects = self.comm.send_latent_data_async(
                chunk_idx=chunk_idx,
                latents=latents,
                original_latents=original_latents,
                patched_x_shape=patched_x_shape,
                current_start=current_start,
                current_end=current_end,
                current_step=current_step,
            )
        self.logger.debug("Sent latent data for chunk %s", chunk_idx)
        return work_objects

    def receive_latent_data_async(self, num_steps: int) -> LatentData:
        """
        Asynchronously receive latent data from the previous rank.

        Receive buffers are obtained through ``self.buffer_mgr``; hand the
        result to ``release_latent_data`` when it is no longer needed so the
        buffers go back to the pool.

        Args:
            num_steps: Number of denoising steps

        Returns:
            LatentData object containing all received data
        """
        with self._timed_transfer("receive_latent_data_async"):
            (chunk_idx, latents, original_latents, current_start,
             current_end, current_step, patched_x_shape) = \
                self.comm.recv_latent_data_async(num_steps, self.buffer_mgr)
        self.logger.debug("Received latent data for chunk %s", chunk_idx)

        return LatentData(
            chunk_idx=chunk_idx,
            latents=latents,
            original_latents=original_latents,
            current_start=current_start,
            current_end=current_end,
            current_step=current_step,
            patched_x_shape=patched_x_shape,
        )

    def release_latent_data(self, latent_data: Optional[LatentData]) -> None:
        """
        Return received latent-data buffers to the buffer pool.

        Does nothing when ``latent_data`` is None or no buffer manager is set.
        Tensors are returned under the pool keys ``"latent"``, ``"origin"``
        and ``"misc"``.
        """
        if latent_data is None or self.buffer_mgr is None:
            return
        self.buffer_mgr.return_buffer(latent_data.latents, "latent")
        self.buffer_mgr.return_buffer(latent_data.original_latents, "origin")
        self.buffer_mgr.return_buffer(latent_data.patched_x_shape, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_start, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_end, "misc")

    def send_prompt_async(self, prompt: str, device: torch.device) -> List[Any]:
        """
        Send a prompt string to another rank via the communicator.

        Args:
            prompt: Prompt text to send.
            device: Device handed through to the communicator (presumably where
                the encoded prompt tensor is staged — TODO confirm).

        Returns:
            Whatever the communicator's ``send_prompt_async`` returns
            (annotated as a list of work objects).
        """
        return self.comm.send_prompt_async(prompt, device)

    def recv_prompt_async(self) -> str:
        """Receive a prompt string via the communicator's ``recv_prompt_async``."""
        return self.comm.recv_prompt_async()

    def send_kv_cache_blocks(self, block_indices: List[int], donor_rank: int) -> None:
        """
        Send KV cache blocks to all ranks.

        Args:
            block_indices: List of block indices to send
            donor_rank: Rank that owns the KV cache data

        Raises:
            RuntimeError: If no KV cache manager was provided.
        """
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")

        with CommunicationTimer(f"send_kv_cache_blocks {len(block_indices)} blocks", self.logger):
            self.kv_cache_mgr.broadcast_kv_blocks(block_indices, donor_rank)
            self.logger.debug("Sent KV cache blocks %s from rank %s", block_indices, donor_rank)

    def rebalance_kv_cache(self,
                           old_intervals: torch.Tensor,
                           new_intervals: torch.Tensor,
                           total_blocks: int) -> None:
        """
        Rebalance KV cache ownership based on new block intervals.

        Args:
            old_intervals: Previous block intervals [world_size, 2]
            new_intervals: New block intervals [world_size, 2]
            total_blocks: Total number of blocks

        Raises:
            RuntimeError: If no KV cache manager was provided.
        """
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")

        with CommunicationTimer("rebalance_kv_cache", self.logger):
            self.kv_cache_mgr.rebalance_kv_cache_by_diff(old_intervals, new_intervals, total_blocks)
            self.logger.info("Rebalanced KV cache ownership")

    def broadcast_tensor(self, tensor: torch.Tensor, src: int) -> None:
        """
        Broadcast a tensor from source to all ranks.

        Args:
            tensor: Tensor to broadcast
            src: Source rank
        """
        with CommunicationTimer(f"broadcast_tensor from rank {src}", self.logger):
            self.comm.broadcast_tensor(tensor, src)
            self.logger.debug("Broadcasted tensor from rank %s, shape: %s", src, tensor.shape)

    def all_gather_tensors(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """
        Gather tensors from all ranks.

        Args:
            tensor: Local tensor to gather

        Returns:
            List of tensors from all ranks
        """
        with CommunicationTimer("all_gather_tensors", self.logger):
            gather_list = self.comm.all_gather_tensors(tensor)
            self.logger.debug("Gathered tensors from all ranks, local shape: %s", tensor.shape)
            return gather_list

    def wait_for_outstanding(self, max_outstanding: Optional[int] = None) -> None:
        """
        Wait for outstanding operations to complete.

        Args:
            max_outstanding: Maximum number of outstanding operations to keep
        """
        with CommunicationTimer("wait_for_outstanding", self.logger):
            self.comm.wait_for_outstanding(max_outstanding)

    def barrier(self) -> None:
        """Synchronize all ranks."""
        with CommunicationTimer("barrier", self.logger):
            self.comm.barrier()

    def get_performance_metrics(self) -> PerformanceMetrics:
        """
        Get performance metrics for data transfer operations.

        Only ``communication_time`` is populated here: it is the mean time per
        recorded latent-data transfer. The remaining fields are left at 0.0
        for the caller / buffer manager to fill in.

        Returns:
            PerformanceMetrics object containing timing information
        """
        return PerformanceMetrics(
            dit_time=0.0,  # Would be filled by caller
            total_time=0.0,  # Would be filled by caller
            communication_time=self._average_transfer_time(),
            buffer_allocation_time=0.0,  # Would be tracked by buffer manager
        )

    def get_statistics(self) -> dict:
        """
        Get transfer statistics.

        Returns:
            Dictionary with ``transfer_count``, ``total_transfer_time``,
            ``avg_transfer_time`` (latent transfers only), the communicator's
            statistics, and the buffer manager's statistics (None when no
            buffer manager is set).
        """
        return {
            "transfer_count": self.transfer_count,
            "total_transfer_time": self.total_transfer_time,
            "avg_transfer_time": self._average_transfer_time(),
            "communicator_stats": self.comm.get_statistics(),
            "buffer_manager_stats": self.buffer_mgr.get_statistics() if self.buffer_mgr else None,
        }

    def print_statistics(self) -> None:
        """Log transfer statistics at INFO level; nested stats are flattened one level."""
        stats = self.get_statistics()
        self.logger.info("Model Data Transfer Statistics:")
        for key, value in stats.items():
            if key in ("communicator_stats", "buffer_manager_stats"):
                # Nested statistics dicts; skipped entirely when empty/None.
                if value:
                    self.logger.info("  %s:", key)
                    for sub_key, sub_value in value.items():
                        self.logger.info("    %s: %s", sub_key, sub_value)
            else:
                self.logger.info("  %s: %s", key, value)

    def cleanup(self) -> None:
        """Clear the buffer manager's buffers (if any) and log completion."""
        if self.buffer_mgr:
            self.buffer_mgr.clear_buffers()
        self.logger.info("Model data transfer cleanup completed")

    def __del__(self):
        """Best-effort cleanup when the transfer manager is garbage-collected."""
        try:
            self.cleanup()
        except Exception:
            # Deliberately swallowed: finalizers must not raise, and attributes
            # may be missing if __init__ failed or during interpreter shutdown.
            # NOTE(review): if buffer_mgr is shared with other owners, this
            # clears its buffers for them too — confirm that is intended.
            pass