""" KV Cache management for distributed inference. This module provides functionality for managing and rebalancing KV caches across distributed ranks during inference. """ import torch import torch.distributed as dist from typing import List, Dict, Tuple, Optional import logging from .utils import CommunicationTags, CommunicationTimer from .data_containers import KVCacheData, BlockInterval class KVCacheManager: """ Manages KV cache operations for distributed inference. This class handles KV cache broadcasting, rebalancing, and ownership management across distributed ranks. """ def __init__(self, pipeline, device: torch.device): """ Initialize the KV cache manager. Args: pipeline: The inference pipeline containing KV caches device: GPU device for operations """ self.pipeline = pipeline self.device = device self.frame_seq_length = pipeline.frame_seq_length self.time_step_length = len(pipeline.denoising_step_list) # Setup logging self.logger = logging.getLogger(f"KVCacheManager_{device}") self.logger.propagate = False if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter( f'[KVCacheManager {device}] %(asctime)s - %(name)s - %(levelname)s - %(message)s' ) handler.setFormatter(formatter) self.logger.addHandler(handler) def broadcast_kv_blocks(self, block_indices: List[int], donor_rank: int) -> None: """ Broadcast kv_cache1 entries for the specified block indices from donor_rank to all ranks. This ensures the receiver rank has the up-to-date KV cache when ownership moves. Args: block_indices: List of block indices to broadcast donor_rank: Rank that owns the KV cache data """ if len(block_indices) == 0: return rank = dist.get_rank() with CommunicationTimer(f"broadcast_kv_blocks from rank {donor_rank}", self.logger): for bi in block_indices: # Broadcast key cache if self.pipeline.kv_cache1[bi]['k'].device != self.device: self.pipeline.kv_cache1[bi]['k'] = self.pipeline.kv_cache1[bi]['k'].to(self.device) self.pipeline.kv_cache1[bi]['v'] = self.pipeline.kv_cache1[bi]['v'].to(self.device) dist.barrier() dist.broadcast(self.pipeline.kv_cache1[bi]['k'], src=donor_rank) # Broadcast value cache dist.broadcast(self.pipeline.kv_cache1[bi]['v'], src=donor_rank) # Broadcast global end index dist.broadcast(self.pipeline.kv_cache1[bi]['global_end_index'], src=donor_rank) # Broadcast local end index dist.broadcast(self.pipeline.kv_cache1[bi]['local_end_index'], src=donor_rank) # Adjust global_end_index for the receiving rank if donor_rank > rank: self.pipeline.kv_cache1[bi]['global_end_index'] += self.frame_seq_length * (donor_rank - rank) * self.time_step_length self.logger.debug(f"Broadcasted KV cache for blocks {block_indices} from rank {donor_rank}") def compute_block_owners(self, block_intervals: torch.Tensor, total_blocks: int) -> torch.Tensor: """ Given block intervals in [start, end) format for all ranks, return a tensor where each entry is the owner rank of that block index. Args: block_intervals: Block intervals for all ranks [world_size, 2] total_blocks: Total number of blocks Returns: Tensor of length total_blocks with owner ranks """ world_size = block_intervals.shape[0] owners = torch.full((total_blocks,), -1, dtype=torch.int64, device=block_intervals.device) for r in range(world_size): s = int(block_intervals[r, 0].item()) e = int(block_intervals[r, 1].item()) if e > s: owners[s:e] = r self.logger.debug(f"Computed block owners: {owners.tolist()}") return owners def rebalance_kv_cache_by_diff(self, old_block_intervals: torch.Tensor, new_block_intervals: torch.Tensor, total_blocks: int) -> None: """ Compare ownership from old to new intervals and broadcast KV cache for blocks whose owner changes. For each moved block i, use the previous owner's rank as src to broadcast pipeline.kv_cache1[i]['k'/'v'/...] to all ranks so the new owner has the correct state. Args: old_block_intervals: Previous block intervals [world_size, 2] new_block_intervals: New block intervals [world_size, 2] total_blocks: Total number of blocks """ with CommunicationTimer("rebalance_kv_cache_by_diff", self.logger): old_owners = self.compute_block_owners(old_block_intervals, total_blocks) new_owners = self.compute_block_owners(new_block_intervals, total_blocks) # Find blocks that changed ownership moved_by_src = {} for i in range(total_blocks): o = int(old_owners[i].item()) n = int(new_owners[i].item()) if o != n and o >= 0: if o not in moved_by_src: moved_by_src[o] = [] moved_by_src[o].append(i) # Synchronize before broadcasting dist.barrier() # Broadcast per donor rank (can batch multiple blocks per src) for src, blocks in moved_by_src.items(): self.broadcast_kv_blocks(blocks, donor_rank=src) self.logger.info(f"Rebalanced KV cache: {len(moved_by_src)} ranks had ownership changes") def get_kv_cache_statistics(self, block_intervals: torch.Tensor, total_blocks: int) -> Dict[str, any]: """ Get statistics about KV cache distribution. Args: block_intervals: Current block intervals [world_size, 2] total_blocks: Total number of blocks Returns: Dictionary containing KV cache statistics """ owners = self.compute_block_owners(block_intervals, total_blocks) # Count blocks per rank block_counts = {} for rank in range(block_intervals.shape[0]): block_counts[rank] = int((owners == rank).sum().item()) # Calculate memory usage per rank (approximate) memory_per_block = 0 if hasattr(self.pipeline, 'kv_cache1') and len(self.pipeline.kv_cache1) > 0: # Estimate memory per block based on first block first_block = self.pipeline.kv_cache1[0] if 'k' in first_block and 'v' in first_block: k_memory = first_block['k'].numel() * first_block['k'].element_size() v_memory = first_block['v'].numel() * first_block['v'].element_size() memory_per_block = k_memory + v_memory memory_usage = {rank: block_counts[rank] * memory_per_block for rank in block_counts} return { "block_counts": block_counts, "memory_usage_bytes": memory_usage, "total_blocks": total_blocks, "memory_per_block_bytes": memory_per_block, "frame_seq_length": self.frame_seq_length } def print_kv_cache_statistics(self, block_intervals: torch.Tensor, total_blocks: int) -> None: """ Print KV cache statistics. Args: block_intervals: Current block intervals [world_size, 2] total_blocks: Total number of blocks """ stats = self.get_kv_cache_statistics(block_intervals, total_blocks) self.logger.info("KV Cache Statistics:") self.logger.info(f" Total blocks: {stats['total_blocks']}") self.logger.info(f" Memory per block: {stats['memory_per_block_bytes']} bytes") self.logger.info(f" Frame sequence length: {stats['frame_seq_length']}") self.logger.info(" Block distribution:") for rank, count in stats['block_counts'].items(): memory_mb = stats['memory_usage_bytes'][rank] / (1024 * 1024) self.logger.info(f" Rank {rank}: {count} blocks, {memory_mb:.2f} MB") def validate_kv_cache_consistency(self, block_intervals: torch.Tensor, total_blocks: int) -> bool: """ Validate that KV cache ownership is consistent with block intervals. Args: block_intervals: Current block intervals [world_size, 2] total_blocks: Total number of blocks Returns: True if consistent, False otherwise """ owners = self.compute_block_owners(block_intervals, total_blocks) # Check that all blocks have owners unowned_blocks = (owners == -1).sum().item() if unowned_blocks > 0: self.logger.error(f"Found {unowned_blocks} unowned blocks") return False # Check that block intervals are contiguous and non-overlapping for rank in range(block_intervals.shape[0]): start = int(block_intervals[rank, 0].item()) end = int(block_intervals[rank, 1].item()) if start < 0 or end > total_blocks or start >= end: self.logger.error(f"Invalid block interval for rank {rank}: [{start}, {end})") return False # Check that all blocks in this interval are owned by this rank for block_idx in range(start, end): if int(owners[block_idx].item()) != rank: self.logger.error(f"Block {block_idx} not owned by rank {rank}") return False self.logger.debug("KV cache consistency validation passed") return True def cleanup_kv_cache(self, block_intervals: torch.Tensor, total_blocks: int) -> None: """ Clean up KV cache for blocks not owned by current rank. Args: block_intervals: Current block intervals [world_size, 2] total_blocks: Total number of blocks """ rank = dist.get_rank() owners = self.compute_block_owners(block_intervals, total_blocks) cleaned_blocks = 0 for block_idx in range(total_blocks): if int(owners[block_idx].item()) != rank: # Clear KV cache for blocks not owned by this rank if hasattr(self.pipeline, 'kv_cache1') and block_idx < len(self.pipeline.kv_cache1): if 'k' in self.pipeline.kv_cache1[block_idx]: self.pipeline.kv_cache1[block_idx]['k'].zero_() if 'v' in self.pipeline.kv_cache1[block_idx]: self.pipeline.kv_cache1[block_idx]['v'].zero_() cleaned_blocks += 1 self.logger.info(f"Cleaned up KV cache for {cleaned_blocks} blocks not owned by rank {rank}")