# Provenance (captured from the Hugging Face Hub file view):
#   Uploaded by: multimodalart (HF Staff)
#   Commit: 5c93746 (verified) - "Upload folder using huggingface_hub"
#   File size: 11.5 kB
"""
KV Cache management for distributed inference.
This module provides functionality for managing and rebalancing KV caches
across distributed ranks during inference.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist

from .utils import CommunicationTags, CommunicationTimer
from .data_containers import KVCacheData, BlockInterval
class KVCacheManager:
    """
    Manages KV cache operations for distributed inference.

    This class handles KV cache broadcasting, rebalancing, and ownership
    management across distributed ranks. Ownership is described by a
    ``[world_size, 2]`` tensor of half-open ``[start, end)`` block intervals,
    one row per rank. The cache itself is read from ``pipeline.kv_cache1``,
    which is indexed by block and whose entries are accessed with the keys
    ``'k'``, ``'v'``, ``'global_end_index'`` and ``'local_end_index'``.
    """

    def __init__(self, pipeline, device: torch.device):
        """
        Initialize the KV cache manager.

        Args:
            pipeline: The inference pipeline containing KV caches. Must expose
                ``frame_seq_length`` and ``denoising_step_list``.
            device: GPU device for operations
        """
        self.pipeline = pipeline
        self.device = device
        self.frame_seq_length = pipeline.frame_seq_length
        self.time_step_length = len(pipeline.denoising_step_list)

        # Setup logging. One logger per device name; propagation is disabled
        # and the handler is only attached once so that constructing several
        # managers for the same device does not duplicate log lines.
        self.logger = logging.getLogger(f"KVCacheManager_{device}")
        self.logger.propagate = False
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                f'[KVCacheManager {device}] %(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def broadcast_kv_blocks(self, block_indices: List[int], donor_rank: int) -> None:
        """
        Broadcast kv_cache1 entries for the specified block indices from donor_rank to all ranks.

        This ensures the receiver rank has the up-to-date KV cache when ownership moves.
        This is a collective operation: every rank must call it with the same
        ``block_indices`` and ``donor_rank``.

        Args:
            block_indices: List of block indices to broadcast
            donor_rank: Rank that owns the KV cache data
        """
        # len() rather than truthiness so non-list sequences keep working.
        if len(block_indices) == 0:
            return

        rank = dist.get_rank()

        with CommunicationTimer(f"broadcast_kv_blocks from rank {donor_rank}", self.logger):
            for bi in block_indices:
                cache = self.pipeline.kv_cache1[bi]

                # Make sure the key/value buffers live on this manager's
                # device before they take part in the collective.
                if cache['k'].device != self.device:
                    cache['k'] = cache['k'].to(self.device)
                    cache['v'] = cache['v'].to(self.device)

                # Reached unconditionally: every rank must issue the same
                # sequence of collectives, whether or not it moved tensors.
                dist.barrier()

                # Broadcast key cache
                dist.broadcast(cache['k'], src=donor_rank)
                # Broadcast value cache
                dist.broadcast(cache['v'], src=donor_rank)
                # Broadcast global end index
                dist.broadcast(cache['global_end_index'], src=donor_rank)
                # Broadcast local end index
                dist.broadcast(cache['local_end_index'], src=donor_rank)

                # Adjust global_end_index on ranks below the donor.
                # NOTE(review): this runs on every rank with rank < donor_rank,
                # not only on the new owner; presumably it compensates for a
                # per-rank offset in the pipeline schedule - confirm.
                if donor_rank > rank:
                    cache['global_end_index'] += (
                        self.frame_seq_length * (donor_rank - rank) * self.time_step_length
                    )

        self.logger.debug(f"Broadcasted KV cache for blocks {block_indices} from rank {donor_rank}")

    def compute_block_owners(self, block_intervals: torch.Tensor, total_blocks: int) -> torch.Tensor:
        """
        Given block intervals in [start, end) format for all ranks, return a tensor
        where each entry is the owner rank of that block index.

        Blocks covered by no interval are marked ``-1``. When intervals
        overlap, the higher rank wins because rows are applied in rank order.

        Args:
            block_intervals: Block intervals for all ranks [world_size, 2]
            total_blocks: Total number of blocks

        Returns:
            Tensor of length total_blocks with owner ranks (int64, on the same
            device as ``block_intervals``)
        """
        owners = torch.full((total_blocks,), -1, dtype=torch.int64, device=block_intervals.device)

        # A single tolist() replaces two .item() calls (and syncs) per rank.
        for r, row in enumerate(block_intervals.tolist()):
            s = int(row[0])
            e = int(row[1])
            if e > s:
                owners[s:e] = r

        # Only materialise the owner list when it will actually be logged.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f"Computed block owners: {owners.tolist()}")
        return owners

    def rebalance_kv_cache_by_diff(self, old_block_intervals: torch.Tensor,
                                   new_block_intervals: torch.Tensor, total_blocks: int) -> None:
        """
        Compare ownership from old to new intervals and broadcast KV cache for blocks whose owner changes.

        For each moved block i, use the previous owner's rank as src to broadcast
        pipeline.kv_cache1[i]['k'/'v'/...] to all ranks so the new owner has the correct state.
        Blocks that previously had no owner are skipped because there is no
        donor to broadcast from. This is a collective operation.

        Args:
            old_block_intervals: Previous block intervals [world_size, 2]
            new_block_intervals: New block intervals [world_size, 2]
            total_blocks: Total number of blocks
        """
        with CommunicationTimer("rebalance_kv_cache_by_diff", self.logger):
            old_owners = self.compute_block_owners(old_block_intervals, total_blocks).tolist()
            new_owners = self.compute_block_owners(new_block_intervals, total_blocks).tolist()

            # Group blocks that changed ownership by their previous owner.
            # Insertion order follows block index, so all ranks iterate the
            # donors in the same order given identical interval tensors.
            moved_by_src: Dict[int, List[int]] = {}
            for i, (o, n) in enumerate(zip(old_owners, new_owners)):
                if o != n and o >= 0:
                    moved_by_src.setdefault(o, []).append(i)

            # Synchronize before broadcasting
            dist.barrier()

            # Broadcast per donor rank (can batch multiple blocks per src)
            for src, blocks in moved_by_src.items():
                self.broadcast_kv_blocks(blocks, donor_rank=src)

        self.logger.info(f"Rebalanced KV cache: {len(moved_by_src)} ranks had ownership changes")

    def get_kv_cache_statistics(self, block_intervals: torch.Tensor, total_blocks: int) -> Dict[str, Any]:
        """
        Get statistics about KV cache distribution.

        Memory figures are an estimate: the size of block 0's ``'k'`` and
        ``'v'`` tensors is used for every block.

        Args:
            block_intervals: Current block intervals [world_size, 2]
            total_blocks: Total number of blocks

        Returns:
            Dictionary containing KV cache statistics with the keys
            ``block_counts``, ``memory_usage_bytes``, ``total_blocks``,
            ``memory_per_block_bytes`` and ``frame_seq_length``.
        """
        owners = self.compute_block_owners(block_intervals, total_blocks)

        # Count blocks per rank
        block_counts = {
            rank: int((owners == rank).sum().item())
            for rank in range(block_intervals.shape[0])
        }

        # Calculate memory usage per rank (approximate). A missing, None or
        # empty kv_cache1 leaves the estimate at 0 instead of raising.
        memory_per_block = 0
        kv_cache = getattr(self.pipeline, 'kv_cache1', None)
        if kv_cache is not None and len(kv_cache) > 0:
            # Estimate memory per block based on first block
            first_block = kv_cache[0]
            if 'k' in first_block and 'v' in first_block:
                k_memory = first_block['k'].numel() * first_block['k'].element_size()
                v_memory = first_block['v'].numel() * first_block['v'].element_size()
                memory_per_block = k_memory + v_memory

        memory_usage = {rank: count * memory_per_block for rank, count in block_counts.items()}

        return {
            "block_counts": block_counts,
            "memory_usage_bytes": memory_usage,
            "total_blocks": total_blocks,
            "memory_per_block_bytes": memory_per_block,
            "frame_seq_length": self.frame_seq_length
        }

    def print_kv_cache_statistics(self, block_intervals: torch.Tensor, total_blocks: int) -> None:
        """
        Print KV cache statistics.

        The statistics are emitted through this manager's logger at INFO level.

        Args:
            block_intervals: Current block intervals [world_size, 2]
            total_blocks: Total number of blocks
        """
        stats = self.get_kv_cache_statistics(block_intervals, total_blocks)

        self.logger.info("KV Cache Statistics:")
        self.logger.info(f" Total blocks: {stats['total_blocks']}")
        self.logger.info(f" Memory per block: {stats['memory_per_block_bytes']} bytes")
        self.logger.info(f" Frame sequence length: {stats['frame_seq_length']}")
        self.logger.info(" Block distribution:")
        for rank, count in stats['block_counts'].items():
            memory_mb = stats['memory_usage_bytes'][rank] / (1024 * 1024)
            self.logger.info(f" Rank {rank}: {count} blocks, {memory_mb:.2f} MB")

    def validate_kv_cache_consistency(self, block_intervals: torch.Tensor, total_blocks: int) -> bool:
        """
        Validate that KV cache ownership is consistent with block intervals.

        Validation fails when any block has no owner, when an interval is out
        of range or empty (``start >= end``), or when a block inside a rank's
        interval ends up owned by a different rank (i.e. intervals overlap).

        Args:
            block_intervals: Current block intervals [world_size, 2]
            total_blocks: Total number of blocks

        Returns:
            True if consistent, False otherwise
        """
        owners = self.compute_block_owners(block_intervals, total_blocks).tolist()

        # Check that all blocks have owners
        unowned_blocks = owners.count(-1)
        if unowned_blocks > 0:
            self.logger.error(f"Found {unowned_blocks} unowned blocks")
            return False

        # Check that every interval is in range, non-empty and non-overlapping
        for rank, row in enumerate(block_intervals.tolist()):
            start = int(row[0])
            end = int(row[1])

            if start < 0 or end > total_blocks or start >= end:
                self.logger.error(f"Invalid block interval for rank {rank}: [{start}, {end})")
                return False

            # Check that all blocks in this interval are owned by this rank
            for block_idx in range(start, end):
                if owners[block_idx] != rank:
                    self.logger.error(f"Block {block_idx} not owned by rank {rank}")
                    return False

        self.logger.debug("KV cache consistency validation passed")
        return True

    def cleanup_kv_cache(self, block_intervals: torch.Tensor, total_blocks: int) -> None:
        """
        Clean up KV cache for blocks not owned by current rank.

        The ``'k'`` and ``'v'`` tensors of every block owned by another rank
        (or by no rank) are zeroed in place; the tensors are not freed.

        Args:
            block_intervals: Current block intervals [world_size, 2]
            total_blocks: Total number of blocks
        """
        rank = dist.get_rank()
        owners = self.compute_block_owners(block_intervals, total_blocks).tolist()
        kv_cache = getattr(self.pipeline, 'kv_cache1', None)

        cleaned_blocks = 0
        for block_idx, owner in enumerate(owners):
            if owner == rank:
                continue
            # Skip blocks that have no cache entry on this rank.
            if kv_cache is None or block_idx >= len(kv_cache):
                continue

            # Clear KV cache for blocks not owned by this rank
            entry = kv_cache[block_idx]
            for key in ('k', 'v'):
                if key in entry:
                    entry[key].zero_()
            cleaned_blocks += 1

        self.logger.info(f"Cleaned up KV cache for {cleaned_blocks} blocks not owned by rank {rank}")