"""
Model data transfer abstraction layer.
This module provides high-level interfaces for transferring model data
between distributed ranks during inference.
"""
import logging
import time
from typing import List, Tuple, Optional, Any

import torch

from .distributed_communicator import DistributedCommunicator
from .buffer_manager import BufferManager
from .kv_cache_manager import KVCacheManager
from .data_containers import LatentData, CommunicationConfig, PerformanceMetrics
from .utils import CommunicationTimer
class ModelDataTransfer:
    """
    High-level interface for model data transfer operations.

    This class encapsulates all model-related data transfer operations,
    providing a clean interface for sending and receiving latent data,
    KV caches, and other model state between ranks.

    Every operation is delegated to the injected communicator, buffer
    manager or KV cache manager; this class adds timing, logging and simple
    transfer counters on top.
    """

    # Keys of get_statistics() whose values are nested dictionaries.
    _NESTED_STAT_KEYS = ("communicator_stats", "buffer_manager_stats")

    def __init__(self, communicator: DistributedCommunicator,
                 buffer_manager: BufferManager,
                 kv_cache_manager: Optional[KVCacheManager] = None,
                 config: Optional[CommunicationConfig] = None):
        """
        Initialize the model data transfer manager.

        Args:
            communicator: Distributed communicator instance
            buffer_manager: Buffer manager for tensor allocation
            kv_cache_manager: KV cache manager (optional)
            config: Communication configuration; a default
                CommunicationConfig is created when omitted
        """
        self.comm = communicator
        self.buffer_mgr = buffer_manager
        self.kv_cache_mgr = kv_cache_manager
        # Explicit None test so a caller-supplied config is never replaced
        # just because it happens to evaluate as falsy.
        self.config = config if config is not None else CommunicationConfig()

        # Setup logging: one logger per rank, detached from the root logger.
        self.logger = logging.getLogger(f"ModelDataTransfer_rank_{communicator.rank}")
        self.logger.propagate = False
        # The logger is shared by name, so only attach a handler once.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                f'[Rank {communicator.rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Performance tracking (latent send/receive calls only).
        self.transfer_count = 0
        self.total_transfer_time = 0.0

    def send_latent_data_async(self, chunk_idx: int, latents: torch.Tensor,
                               original_latents: torch.Tensor, patched_x_shape: torch.Tensor,
                               current_start: torch.Tensor, current_end: torch.Tensor,
                               current_step: int) -> List[Any]:
        """
        Asynchronously send latent data to the next rank.

        Args:
            chunk_idx: Chunk index
            latents: Latent tensor
            original_latents: Original latent tensor
            patched_x_shape: Patched x shape tensor
            current_start: Current start indices
            current_end: Current end indices
            current_step: Current step

        Returns:
            List of work objects for all send operations
        """
        with CommunicationTimer(f"send_latent_data_async chunk_{chunk_idx}", self.logger):
            started = time.perf_counter()
            work_objects = self.comm.send_latent_data_async(
                chunk_idx=chunk_idx,
                latents=latents,
                original_latents=original_latents,
                patched_x_shape=patched_x_shape,
                current_start=current_start,
                current_end=current_end,
                current_step=current_step
            )
            # Counters advance only when the call returned normally. The
            # elapsed time is the duration of the communicator call itself,
            # not of any work still pending on the returned work objects.
            self.total_transfer_time += time.perf_counter() - started
            self.transfer_count += 1
        self.logger.debug("Sent latent data for chunk %s", chunk_idx)
        return work_objects

    def receive_latent_data_async(self, num_steps: int) -> LatentData:
        """
        Asynchronously receive latent data from the previous rank.

        Args:
            num_steps: Number of denoising steps

        Returns:
            LatentData object containing all received data
        """
        with CommunicationTimer("receive_latent_data_async", self.logger):
            started = time.perf_counter()
            chunk_idx, latents, original_latents, current_start, current_end, current_step, patched_x_shape = \
                self.comm.recv_latent_data_async(num_steps, self.buffer_mgr)
            # Same accounting as the send path: successful calls only.
            self.total_transfer_time += time.perf_counter() - started
            self.transfer_count += 1
        self.logger.debug("Received latent data for chunk %s", chunk_idx)
        return LatentData(
            chunk_idx=chunk_idx,
            latents=latents,
            original_latents=original_latents,
            current_start=current_start,
            current_end=current_end,
            current_step=current_step,
            patched_x_shape=patched_x_shape
        )

    def release_latent_data(self, latent_data: Optional[LatentData]) -> None:
        """
        Return received latent-data buffers to the buffer pool.

        Does nothing when ``latent_data`` is None or no buffer manager is set.

        Args:
            latent_data: Object previously produced by receive_latent_data_async
        """
        if latent_data is None or self.buffer_mgr is None:
            return
        self.buffer_mgr.return_buffer(latent_data.latents, "latent")
        self.buffer_mgr.return_buffer(latent_data.original_latents, "origin")
        self.buffer_mgr.return_buffer(latent_data.patched_x_shape, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_start, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_end, "misc")

    def send_prompt_async(self, prompt: str, device: torch.device) -> List[Any]:
        """
        Forward a prompt to the communicator's send_prompt_async.

        Args:
            prompt: Prompt text
            device: Device passed through to the communicator

        Returns:
            Whatever the communicator returns for the send
        """
        return self.comm.send_prompt_async(prompt, device)

    def recv_prompt_async(self) -> str:
        """Return the prompt obtained from the communicator's recv_prompt_async."""
        return self.comm.recv_prompt_async()

    def send_kv_cache_blocks(self, block_indices: List[int], donor_rank: int) -> None:
        """
        Send KV cache blocks to all ranks.

        Args:
            block_indices: List of block indices to send
            donor_rank: Rank that owns the KV cache data

        Raises:
            RuntimeError: If no KV cache manager was provided
        """
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")
        with CommunicationTimer(f"send_kv_cache_blocks {len(block_indices)} blocks", self.logger):
            self.kv_cache_mgr.broadcast_kv_blocks(block_indices, donor_rank)
        self.logger.debug("Sent KV cache blocks %s from rank %s", block_indices, donor_rank)

    def rebalance_kv_cache(self, old_intervals: torch.Tensor,
                           new_intervals: torch.Tensor, total_blocks: int) -> None:
        """
        Rebalance KV cache ownership based on new block intervals.

        Args:
            old_intervals: Previous block intervals [world_size, 2]
            new_intervals: New block intervals [world_size, 2]
            total_blocks: Total number of blocks

        Raises:
            RuntimeError: If no KV cache manager was provided
        """
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")
        with CommunicationTimer("rebalance_kv_cache", self.logger):
            self.kv_cache_mgr.rebalance_kv_cache_by_diff(old_intervals, new_intervals, total_blocks)
        self.logger.info("Rebalanced KV cache ownership")

    def broadcast_tensor(self, tensor: torch.Tensor, src: int) -> None:
        """
        Broadcast a tensor from source to all ranks.

        Args:
            tensor: Tensor to broadcast
            src: Source rank
        """
        with CommunicationTimer(f"broadcast_tensor from rank {src}", self.logger):
            self.comm.broadcast_tensor(tensor, src)
        self.logger.debug("Broadcasted tensor from rank %s, shape: %s", src, tensor.shape)

    def all_gather_tensors(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """
        Gather tensors from all ranks.

        Args:
            tensor: Local tensor to gather

        Returns:
            List of tensors from all ranks
        """
        with CommunicationTimer("all_gather_tensors", self.logger):
            gather_list = self.comm.all_gather_tensors(tensor)
        self.logger.debug("Gathered tensors from all ranks, local shape: %s", tensor.shape)
        return gather_list

    def wait_for_outstanding(self, max_outstanding: Optional[int] = None) -> None:
        """
        Wait for outstanding operations to complete.

        Args:
            max_outstanding: Maximum number of outstanding operations to keep
        """
        with CommunicationTimer("wait_for_outstanding", self.logger):
            self.comm.wait_for_outstanding(max_outstanding)

    def barrier(self) -> None:
        """Synchronize all ranks."""
        with CommunicationTimer("barrier", self.logger):
            self.comm.barrier()

    def _average_transfer_time(self) -> float:
        """Return mean seconds per counted transfer (0.0 before any transfer)."""
        # max(1, ...) avoids a division by zero when nothing was transferred.
        return self.total_transfer_time / max(1, self.transfer_count)

    def get_performance_metrics(self) -> PerformanceMetrics:
        """
        Get performance metrics for data transfer operations.

        Returns:
            PerformanceMetrics whose communication_time is the average time
            per counted latent transfer; the remaining fields are 0.0
            placeholders.
        """
        # This is a simplified version - in practice, you'd want to track
        # more detailed timing information
        return PerformanceMetrics(
            dit_time=0.0,  # Would be filled by caller
            total_time=0.0,  # Would be filled by caller
            communication_time=self._average_transfer_time(),
            buffer_allocation_time=0.0  # Would be tracked by buffer manager
        )

    def get_statistics(self) -> dict:
        """
        Get transfer statistics.

        Returns:
            Dictionary containing transfer statistics. The
            "buffer_manager_stats" entry is None when no buffer manager is set.
        """
        # Explicit None test, consistent with release_latent_data, so a buffer
        # manager that evaluates as falsy still reports its statistics.
        if self.buffer_mgr is not None:
            buffer_stats = self.buffer_mgr.get_statistics()
        else:
            buffer_stats = None
        return {
            "transfer_count": self.transfer_count,
            "total_transfer_time": self.total_transfer_time,
            "avg_transfer_time": self._average_transfer_time(),
            "communicator_stats": self.comm.get_statistics(),
            "buffer_manager_stats": buffer_stats
        }

    def print_statistics(self) -> None:
        """Print transfer statistics."""
        stats = self.get_statistics()
        self.logger.info("Model Data Transfer Statistics:")
        for key, value in stats.items():
            if key in self._NESTED_STAT_KEYS:
                # Nested sections are skipped entirely when empty or None.
                if value:
                    self.logger.info(f"  {key}:")
                    for sub_key, sub_value in value.items():
                        self.logger.info(f"    {sub_key}: {sub_value}")
            else:
                self.logger.info(f"  {key}: {value}")

    def cleanup(self) -> None:
        """Clean up resources by clearing the buffer manager's buffers, if any."""
        if self.buffer_mgr is not None:
            self.buffer_mgr.clear_buffers()
        self.logger.info("Model data transfer cleanup completed")

    def __del__(self):
        """Cleanup when the transfer manager is destroyed."""
        try:
            self.cleanup()
        except Exception:
            # Deliberately best-effort: the object may be only partially
            # initialised, and a finalizer has no caller to report to.
            pass
|