Spaces:
Running on Zero
Running on Zero
"""
Data containers for communication operations.

This module defines data structures used for communication between distributed ranks.
"""
| from dataclasses import dataclass | |
| from typing import Optional, List, Dict, Any | |
| import torch | |
@dataclass
class LatentData:
    """
    Container for latent data and related information.

    This class encapsulates all the data that needs to be transferred between
    ranks during the inference pipeline.

    Attributes:
        chunk_idx: Integer index of the chunk (not validated here).
        latents: Tensor payload; only its type is checked.
        original_latents: Tensor payload; only its type is checked.
        current_start: Tensor; presumably a start position -- TODO confirm
            against the producer of this object.
        current_end: Tensor; presumably an end position -- TODO confirm.
        current_step: Integer step counter (not validated here).
        patched_x_shape: Tensor; presumably carries a shape encoded as a
            tensor -- TODO confirm.

    Raises:
        TypeError: If any tensor-typed field is not a ``torch.Tensor``.
    """

    chunk_idx: int
    latents: torch.Tensor
    original_latents: torch.Tensor
    current_start: torch.Tensor
    current_end: torch.Tensor
    current_step: int
    patched_x_shape: torch.Tensor

    def __post_init__(self) -> None:
        """Check that every tensor-typed field really holds a ``torch.Tensor``.

        Only the type is verified; shapes, dtypes and devices are not
        inspected. Fields are checked in declaration order, so the first
        offending field is the one that gets reported.
        """
        tensor_fields = (
            "latents",
            "original_latents",
            "current_start",
            "current_end",
            "patched_x_shape",
        )
        for field_name in tensor_fields:
            if not isinstance(getattr(self, field_name), torch.Tensor):
                raise TypeError(f"{field_name} must be a torch.Tensor")
@dataclass
class KVCacheData:
    """
    Container for KV cache data.

    This class encapsulates key-value cache information for transformer blocks.

    Attributes:
        block_index: Integer index of the block this cache entry belongs to
            (not validated here).
        k_cache: Tensor holding the key cache; only its type is checked.
        v_cache: Tensor holding the value cache; only its type is checked.
        global_end_index: Tensor; presumably an end offset in global
            coordinates -- TODO confirm against the cache writer.
        local_end_index: Tensor; presumably an end offset in local
            coordinates -- TODO confirm.

    Raises:
        TypeError: If any tensor-typed field is not a ``torch.Tensor``.
    """

    block_index: int
    k_cache: torch.Tensor
    v_cache: torch.Tensor
    global_end_index: torch.Tensor
    local_end_index: torch.Tensor

    def __post_init__(self) -> None:
        """Check that every tensor-typed field really holds a ``torch.Tensor``.

        Only the type is verified; shapes, dtypes and devices are not
        inspected. Fields are checked in declaration order, so the first
        offending field is the one that gets reported.
        """
        tensor_fields = (
            "k_cache",
            "v_cache",
            "global_end_index",
            "local_end_index",
        )
        for field_name in tensor_fields:
            if not isinstance(getattr(self, field_name), torch.Tensor):
                raise TypeError(f"{field_name} must be a torch.Tensor")
@dataclass
class CommunicationConfig:
    """
    Configuration for communication operations.

    This class holds configuration parameters for distributed communication.
    Every field has a default, so ``CommunicationConfig()`` is valid.

    Attributes:
        max_outstanding: Must be at least 1. Presumably the cap on in-flight
            operations -- TODO confirm with the consumer of this config.
        buffer_pool_size: Must be at least 1.
        enable_buffer_reuse: Boolean flag (not validated here).
        communication_timeout: Must be strictly positive. Units are not
            established in this file -- presumably seconds, TODO confirm.
        enable_async_communication: Boolean flag (not validated here).

    Raises:
        ValueError: If a numeric parameter is outside its allowed range.
    """

    max_outstanding: int = 1
    buffer_pool_size: int = 10
    enable_buffer_reuse: bool = True
    communication_timeout: float = 30.0
    enable_async_communication: bool = True

    def __post_init__(self) -> None:
        """Reject out-of-range numeric settings right after construction."""
        if self.max_outstanding < 1:
            raise ValueError("max_outstanding must be at least 1")
        if self.buffer_pool_size < 1:
            raise ValueError("buffer_pool_size must be at least 1")
        if self.communication_timeout <= 0:
            raise ValueError("communication_timeout must be positive")
@dataclass
class BlockInterval:
    """
    Container for block interval information.

    This class represents a block interval [start, end) for a specific rank:
    ``start`` is inclusive and ``end`` is exclusive.

    Attributes:
        start: First block index in the interval; must be non-negative.
        end: One past the last block index; must be greater than ``start``,
            so an interval always holds at least one block.
        rank: Rank associated with the interval; must be non-negative.

    Raises:
        ValueError: If ``start`` or ``rank`` is negative, or if ``end`` is
            not greater than ``start``.
    """

    start: int
    end: int
    rank: int

    def __post_init__(self) -> None:
        """Validate block interval parameters."""
        if self.start < 0:
            raise ValueError("start must be non-negative")
        if self.end <= self.start:
            raise ValueError("end must be greater than start")
        if self.rank < 0:
            raise ValueError("rank must be non-negative")

    def size(self) -> int:
        """Return the number of block indices covered, i.e. ``end - start``."""
        return self.end - self.start

    def contains(self, block_index: int) -> bool:
        """Return True when ``start <= block_index < end``."""
        return self.start <= block_index < self.end
@dataclass
class PerformanceMetrics:
    """
    Container for performance metrics.

    This class holds timing and performance information for communication
    operations. All four timings must be non-negative; their units are not
    established in this file (presumably seconds -- TODO confirm).

    Attributes:
        dit_time: Time attributed to "dit"; presumably the DiT model
            forward pass -- TODO confirm with the code that fills this in.
        total_time: Overall elapsed time.
        communication_time: Time spent in communication.
        buffer_allocation_time: Time spent allocating buffers.

    Raises:
        ValueError: If any timing is negative.
    """

    dit_time: float
    total_time: float
    communication_time: float
    buffer_allocation_time: float

    def __post_init__(self) -> None:
        """Validate performance metrics (every timing must be >= 0)."""
        if self.dit_time < 0:
            raise ValueError("dit_time must be non-negative")
        if self.total_time < 0:
            raise ValueError("total_time must be non-negative")
        if self.communication_time < 0:
            raise ValueError("communication_time must be non-negative")
        if self.buffer_allocation_time < 0:
            raise ValueError("buffer_allocation_time must be non-negative")

    def efficiency(self) -> float:
        """Return the fraction of total time not spent communicating.

        Computed as ``(total_time - communication_time) / total_time``.
        Returns 0.0 when ``total_time`` is zero to avoid dividing by zero.
        The result is negative when ``communication_time`` exceeds
        ``total_time``; no clamping is applied.
        """
        if self.total_time == 0:
            return 0.0
        return (self.total_time - self.communication_time) / self.total_time