# Uploaded by multimodalart (HF Staff) via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 5c93746 (verified); original file size 10.8 kB.
"""
Model data transfer abstraction layer.
This module provides high-level interfaces for transferring model data
between distributed ranks during inference.
"""
import logging
import time
from typing import List, Tuple, Optional, Any

import torch

from .distributed_communicator import DistributedCommunicator
from .buffer_manager import BufferManager
from .kv_cache_manager import KVCacheManager
from .data_containers import LatentData, CommunicationConfig, PerformanceMetrics
from .utils import CommunicationTimer
class ModelDataTransfer:
    """
    High-level interface for model data transfer operations.

    This class encapsulates all model-related data transfer operations,
    providing a clean interface for sending and receiving latent data,
    KV caches, and other model state between ranks.

    Only latent-data sends/receives contribute to the transfer statistics
    (``transfer_count`` / ``total_transfer_time``). The other operations
    (prompts, KV cache, broadcast, gather, barrier) are only wrapped in
    ``CommunicationTimer`` where applicable.
    """

    def __init__(self, communicator: DistributedCommunicator,
                 buffer_manager: BufferManager,
                 kv_cache_manager: Optional[KVCacheManager] = None,
                 config: Optional[CommunicationConfig] = None):
        """
        Initialize the model data transfer manager.

        Args:
            communicator: Distributed communicator instance. Its ``rank``
                attribute is used to name and format the logger.
            buffer_manager: Buffer manager for tensor allocation.
            kv_cache_manager: KV cache manager (optional). The KV-cache
                methods raise ``RuntimeError`` when it is not provided.
            config: Communication configuration. A default
                ``CommunicationConfig()`` is created when omitted.
        """
        self.comm = communicator
        self.buffer_mgr = buffer_manager
        self.kv_cache_mgr = kv_cache_manager
        self.config = config if config is not None else CommunicationConfig()

        # One logger per rank. Propagation is disabled so ancestor handlers do
        # not print every record a second time. logging.getLogger returns the
        # same object for the same name, so the handler is attached only once
        # even if several instances are created for the same rank.
        self.logger = logging.getLogger(f"ModelDataTransfer_rank_{communicator.rank}")
        self.logger.propagate = False
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                f'[Rank {communicator.rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Performance tracking for latent-data transfers (see _record_transfer).
        self.transfer_count = 0
        self.total_transfer_time = 0.0

    def _record_transfer(self, elapsed: float) -> None:
        """Account one completed latent-data transfer that took ``elapsed`` seconds."""
        self.transfer_count += 1
        self.total_transfer_time += elapsed

    def _average_transfer_time(self) -> float:
        """Return mean seconds per recorded transfer (0.0 when nothing was recorded)."""
        # max(1, ...) avoids division by zero before the first transfer.
        return self.total_transfer_time / max(1, self.transfer_count)

    def send_latent_data_async(self, chunk_idx: int, latents: torch.Tensor,
                               original_latents: torch.Tensor, patched_x_shape: torch.Tensor,
                               current_start: torch.Tensor, current_end: torch.Tensor,
                               current_step: int) -> List[Any]:
        """
        Asynchronously send latent data to the next rank.

        Args:
            chunk_idx: Chunk index.
            latents: Latent tensor.
            original_latents: Original latent tensor.
            patched_x_shape: Patched x shape tensor.
            current_start: Current start indices.
            current_end: Current end indices.
            current_step: Current step.

        Returns:
            List of work objects for all send operations, as returned by the
            communicator.

        Note:
            The recorded transfer time covers only the duration of the
            communicator call. Since the call is asynchronous, this presumably
            measures issuing the sends rather than their completion. Completion
            is awaited elsewhere, e.g. via ``wait_for_outstanding``; confirm
            against DistributedCommunicator.
        """
        with CommunicationTimer(f"send_latent_data_async chunk_{chunk_idx}", self.logger):
            started = time.perf_counter()
            work_objects = self.comm.send_latent_data_async(
                chunk_idx=chunk_idx,
                latents=latents,
                original_latents=original_latents,
                patched_x_shape=patched_x_shape,
                current_start=current_start,
                current_end=current_end,
                current_step=current_step,
            )
            self._record_transfer(time.perf_counter() - started)
            # Lazy %-formatting: nothing is formatted unless DEBUG is enabled.
            self.logger.debug("Sent latent data for chunk %s", chunk_idx)
        return work_objects

    def receive_latent_data_async(self, num_steps: int) -> LatentData:
        """
        Asynchronously receive latent data from the previous rank.

        Receive buffers are obtained through ``self.buffer_mgr``, which is
        passed to the communicator. Hand the result back with
        ``release_latent_data`` once it is no longer needed.

        Args:
            num_steps: Number of denoising steps.

        Returns:
            LatentData object containing all received data.
        """
        with CommunicationTimer("receive_latent_data_async", self.logger):
            started = time.perf_counter()
            (chunk_idx, latents, original_latents, current_start,
             current_end, current_step, patched_x_shape) = \
                self.comm.recv_latent_data_async(num_steps, self.buffer_mgr)
            self._record_transfer(time.perf_counter() - started)
            self.logger.debug("Received latent data for chunk %s", chunk_idx)
        return LatentData(
            chunk_idx=chunk_idx,
            latents=latents,
            original_latents=original_latents,
            current_start=current_start,
            current_end=current_end,
            current_step=current_step,
            patched_x_shape=patched_x_shape,
        )

    def release_latent_data(self, latent_data: Optional[LatentData]) -> None:
        """
        Return received latent-data buffers to the buffer pool.

        Does nothing when ``latent_data`` is None or no buffer manager is
        configured. Presumably the pool may hand these tensors out again, so
        callers should not keep using ``latent_data`` afterwards; confirm with
        BufferManager.
        """
        if latent_data is None or self.buffer_mgr is None:
            return
        # Each tensor goes back under the pool tag it is associated with.
        self.buffer_mgr.return_buffer(latent_data.latents, "latent")
        self.buffer_mgr.return_buffer(latent_data.original_latents, "origin")
        self.buffer_mgr.return_buffer(latent_data.patched_x_shape, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_start, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_end, "misc")

    def send_prompt_async(self, prompt: str, device: torch.device) -> List[Any]:
        """
        Send a text prompt through the communicator.

        Args:
            prompt: Prompt text to send.
            device: Device handed to the communicator for the send.

        Returns:
            The work objects returned by the communicator.
        """
        return self.comm.send_prompt_async(prompt, device)

    def recv_prompt_async(self) -> str:
        """Receive a prompt string through the communicator and return it."""
        return self.comm.recv_prompt_async()

    def send_kv_cache_blocks(self, block_indices: List[int], donor_rank: int) -> None:
        """
        Send KV cache blocks to all ranks.

        Args:
            block_indices: List of block indices to send.
            donor_rank: Rank that owns the KV cache data.

        Raises:
            RuntimeError: If no KV cache manager was provided.
        """
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")
        with CommunicationTimer(f"send_kv_cache_blocks {len(block_indices)} blocks", self.logger):
            self.kv_cache_mgr.broadcast_kv_blocks(block_indices, donor_rank)
            # Lazy formatting avoids stringifying a potentially long index list
            # when DEBUG logging is off.
            self.logger.debug("Sent KV cache blocks %s from rank %s", block_indices, donor_rank)

    def rebalance_kv_cache(self, old_intervals: torch.Tensor,
                           new_intervals: torch.Tensor, total_blocks: int) -> None:
        """
        Rebalance KV cache ownership based on new block intervals.

        Args:
            old_intervals: Previous block intervals [world_size, 2].
            new_intervals: New block intervals [world_size, 2].
            total_blocks: Total number of blocks.

        Raises:
            RuntimeError: If no KV cache manager was provided.
        """
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")
        with CommunicationTimer("rebalance_kv_cache", self.logger):
            self.kv_cache_mgr.rebalance_kv_cache_by_diff(old_intervals, new_intervals, total_blocks)
            self.logger.info("Rebalanced KV cache ownership")

    def broadcast_tensor(self, tensor: torch.Tensor, src: int) -> None:
        """
        Broadcast a tensor from source to all ranks.

        Args:
            tensor: Tensor to broadcast.
            src: Source rank.
        """
        with CommunicationTimer(f"broadcast_tensor from rank {src}", self.logger):
            self.comm.broadcast_tensor(tensor, src)
            self.logger.debug("Broadcasted tensor from rank %s, shape: %s", src, tensor.shape)

    def all_gather_tensors(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """
        Gather tensors from all ranks.

        Args:
            tensor: Local tensor to gather.

        Returns:
            List of tensors from all ranks.
        """
        with CommunicationTimer("all_gather_tensors", self.logger):
            gather_list = self.comm.all_gather_tensors(tensor)
            self.logger.debug("Gathered tensors from all ranks, local shape: %s", tensor.shape)
        return gather_list

    def wait_for_outstanding(self, max_outstanding: Optional[int] = None) -> None:
        """
        Wait for outstanding operations to complete.

        Args:
            max_outstanding: Maximum number of outstanding operations to keep.
        """
        with CommunicationTimer("wait_for_outstanding", self.logger):
            self.comm.wait_for_outstanding(max_outstanding)

    def barrier(self) -> None:
        """Synchronize all ranks."""
        with CommunicationTimer("barrier", self.logger):
            self.comm.barrier()

    def get_performance_metrics(self) -> PerformanceMetrics:
        """
        Get performance metrics for data transfer operations.

        Returns:
            PerformanceMetrics whose ``communication_time`` is the mean
            latent-transfer time in seconds. The other fields are
            placeholders set to 0.0.
        """
        # This is a simplified version - in practice, you'd want to track
        # more detailed timing information.
        return PerformanceMetrics(
            dit_time=0.0,  # Would be filled by caller
            total_time=0.0,  # Would be filled by caller
            communication_time=self._average_transfer_time(),
            buffer_allocation_time=0.0,  # Would be tracked by buffer manager
        )

    def get_statistics(self) -> dict:
        """
        Get transfer statistics.

        Returns:
            Dictionary with the transfer count, total and average transfer time
            (seconds), the communicator's statistics, and the buffer manager's
            statistics (None when no buffer manager is configured).
        """
        return {
            "transfer_count": self.transfer_count,
            "total_transfer_time": self.total_transfer_time,
            "avg_transfer_time": self._average_transfer_time(),
            "communicator_stats": self.comm.get_statistics(),
            "buffer_manager_stats": (
                self.buffer_mgr.get_statistics() if self.buffer_mgr is not None else None
            ),
        }

    def print_statistics(self) -> None:
        """Log transfer statistics at INFO level, expanding nested statistic dicts."""
        nested_keys = ("communicator_stats", "buffer_manager_stats")
        stats = self.get_statistics()
        self.logger.info("Model Data Transfer Statistics:")
        for key, value in stats.items():
            if key in nested_keys:
                # Nested sections are skipped entirely when missing or empty.
                if value:
                    self.logger.info(f" {key}:")
                    for sub_key, sub_value in value.items():
                        self.logger.info(f" {sub_key}: {sub_value}")
            else:
                self.logger.info(f" {key}: {value}")

    def cleanup(self) -> None:
        """Clear the buffer manager's buffers (if any) and log completion."""
        if self.buffer_mgr is not None:
            self.buffer_mgr.clear_buffers()
        self.logger.info("Model data transfer cleanup completed")

    def __del__(self):
        """Best-effort cleanup when the transfer manager is garbage-collected."""
        # Deliberately swallow errors: __del__ may run during interpreter
        # shutdown or after a partially failed __init__, when attributes or
        # dependencies may already be gone, and exceptions cannot propagate
        # usefully from a finalizer anyway.
        try:
            self.cleanup()
        except Exception:
            pass