multimodalart's picture
multimodalart HF Staff
Upload folder using huggingface_hub
5c93746 verified
Raw
History Blame Contribute Delete
5.03 kB
"""
Data containers for communication operations.
This module defines data structures used for communication between distributed ranks.
"""
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
import torch
@dataclass
class LatentData:
"""
Container for latent data and related information.
This class encapsulates all the data that needs to be transferred between ranks
during the inference pipeline.
"""
chunk_idx: int
latents: torch.Tensor
original_latents: torch.Tensor
current_start: torch.Tensor
current_end: torch.Tensor
current_step: int
patched_x_shape: torch.Tensor
def __post_init__(self):
"""Validate tensor shapes and types after initialization."""
if not isinstance(self.latents, torch.Tensor):
raise TypeError("latents must be a torch.Tensor")
if not isinstance(self.original_latents, torch.Tensor):
raise TypeError("original_latents must be a torch.Tensor")
if not isinstance(self.current_start, torch.Tensor):
raise TypeError("current_start must be a torch.Tensor")
if not isinstance(self.current_end, torch.Tensor):
raise TypeError("current_end must be a torch.Tensor")
if not isinstance(self.patched_x_shape, torch.Tensor):
raise TypeError("patched_x_shape must be a torch.Tensor")
@dataclass
class KVCacheData:
"""
Container for KV cache data.
This class encapsulates key-value cache information for transformer blocks.
"""
block_index: int
k_cache: torch.Tensor
v_cache: torch.Tensor
global_end_index: torch.Tensor
local_end_index: torch.Tensor
def __post_init__(self):
"""Validate tensor shapes and types after initialization."""
if not isinstance(self.k_cache, torch.Tensor):
raise TypeError("k_cache must be a torch.Tensor")
if not isinstance(self.v_cache, torch.Tensor):
raise TypeError("v_cache must be a torch.Tensor")
if not isinstance(self.global_end_index, torch.Tensor):
raise TypeError("global_end_index must be a torch.Tensor")
if not isinstance(self.local_end_index, torch.Tensor):
raise TypeError("local_end_index must be a torch.Tensor")
@dataclass
class CommunicationConfig:
"""
Configuration for communication operations.
This class holds configuration parameters for distributed communication.
"""
max_outstanding: int = 1
buffer_pool_size: int = 10
enable_buffer_reuse: bool = True
communication_timeout: float = 30.0
enable_async_communication: bool = True
def __post_init__(self):
"""Validate configuration parameters."""
if self.max_outstanding < 1:
raise ValueError("max_outstanding must be at least 1")
if self.buffer_pool_size < 1:
raise ValueError("buffer_pool_size must be at least 1")
if self.communication_timeout <= 0:
raise ValueError("communication_timeout must be positive")
@dataclass
class BlockInterval:
"""
Container for block interval information.
This class represents a block interval [start, end) for a specific rank.
"""
start: int
end: int
rank: int
def __post_init__(self):
"""Validate block interval parameters."""
if self.start < 0:
raise ValueError("start must be non-negative")
if self.end <= self.start:
raise ValueError("end must be greater than start")
if self.rank < 0:
raise ValueError("rank must be non-negative")
@property
def size(self) -> int:
"""Get the size of the block interval."""
return self.end - self.start
def contains(self, block_index: int) -> bool:
"""Check if the block interval contains the given block index."""
return self.start <= block_index < self.end
@dataclass
class PerformanceMetrics:
"""
Container for performance metrics.
This class holds timing and performance information for communication operations.
"""
dit_time: float
total_time: float
communication_time: float
buffer_allocation_time: float
def __post_init__(self):
"""Validate performance metrics."""
if self.dit_time < 0:
raise ValueError("dit_time must be non-negative")
if self.total_time < 0:
raise ValueError("total_time must be non-negative")
if self.communication_time < 0:
raise ValueError("communication_time must be non-negative")
if self.buffer_allocation_time < 0:
raise ValueError("buffer_allocation_time must be non-negative")
@property
def efficiency(self) -> float:
"""Calculate communication efficiency (computation time / total time)."""
if self.total_time == 0:
return 0.0
return (self.total_time - self.communication_time) / self.total_time