Spaces:
Running on Zero
Running on Zero
| """ | |
| Distributed communication abstraction layer. | |
| This module provides a high-level interface for distributed communication operations, | |
| encapsulating the low-level PyTorch distributed primitives. | |
| """ | |
| import torch | |
| import torch.distributed as dist | |
| from typing import List, Tuple, Optional, Any | |
| import logging | |
| import time | |
| from .utils import CommunicationTags, get_next_rank, get_prev_rank, CommunicationTimer | |
| from .data_containers import CommunicationConfig | |
class DistributedCommunicator:
    """
    High-level interface for distributed communication operations.

    This class wraps the ``torch.distributed`` point-to-point (``isend`` /
    ``recv``) and collective (``broadcast`` / ``all_gather`` / ``barrier``)
    calls used to move tensors between ranks. Work handles returned by the
    tensor/header send helpers are remembered in ``outstanding_operations``
    so that ``wait_for_outstanding`` can bound how many stay in flight.
    """

    def __init__(self, rank: int, world_size: int, device: torch.device,
                 config: "Optional[CommunicationConfig]" = None):
        """
        Initialize the distributed communicator.

        Args:
            rank: Current rank
            world_size: Total number of ranks
            device: Device that tensors handled by this communicator live on
            config: Communication configuration; a default
                ``CommunicationConfig()`` is created when omitted

        Raises:
            RuntimeError: If ``torch.distributed`` has not been initialized.
        """
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.config = config if config is not None else CommunicationConfig()

        # Track outstanding operations (work handles, oldest first)
        self.outstanding_operations: List[Any] = []

        # Setup logging: one logger per rank, not propagated to the root
        # logger. The handler is attached only once because getLogger returns
        # the same logger object for the same name.
        self.logger = logging.getLogger(f"DistributedCommunicator_rank_{rank}")
        self.logger.propagate = False
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                f'[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Validate distributed is initialized
        if not dist.is_initialized():
            raise RuntimeError("Distributed not initialized. Call init_distributed() first.")

    def send_tensor_async(self, tensor: torch.Tensor, dst: int, tag: int) -> Any:
        """
        Asynchronously send a tensor to the specified destination.

        The returned work handle is also appended to
        ``outstanding_operations``.

        Args:
            tensor: Tensor to send
            dst: Destination rank
            tag: Communication tag

        Returns:
            Work object for the send operation

        Raises:
            ValueError: If ``tensor`` is not on the communicator's device.
        """
        if tensor.device != self.device:
            raise ValueError(f"Tensor device {tensor.device} doesn't match communicator device {self.device}")

        work = dist.isend(tensor, dst=dst, tag=tag)
        self.outstanding_operations.append(work)
        self.logger.debug(
            "Started async send to rank %s with tag %s, tensor shape: %s",
            dst, tag, tensor.shape,
        )
        return work

    def recv_tensor(self, src: int, tag: int, shape: Tuple[int, ...],
                    dtype: torch.dtype) -> torch.Tensor:
        """
        Receive a tensor from the specified source (blocking).

        Args:
            src: Source rank
            tag: Communication tag
            shape: Expected tensor shape
            dtype: Expected tensor dtype

        Returns:
            Received tensor, freshly allocated on the communicator's device
        """
        tensor = torch.empty(shape, dtype=dtype, device=self.device)
        with CommunicationTimer(f"recv_tensor from rank {src}", self.logger):
            dist.recv(tensor, src=src, tag=tag)
        self.logger.debug(
            "Received tensor from rank %s with tag %s, shape: %s",
            src, tag, tensor.shape,
        )
        return tensor

    def send_header_and_tensor_async(self, header: torch.Tensor, tensor: torch.Tensor,
                                     dst: int, tag_header: int, tag_tensor: int) -> Tuple[Any, Any]:
        """
        Asynchronously send a header and tensor pair.

        Both work handles are also appended to ``outstanding_operations``.

        Args:
            header: Header tensor containing metadata
            tensor: Data tensor
            dst: Destination rank
            tag_header: Tag for header
            tag_tensor: Tag for tensor

        Returns:
            Tuple of (header_work, tensor_work)

        Raises:
            ValueError: If either tensor is not on the communicator's device.
        """
        if header.device != self.device or tensor.device != self.device:
            raise ValueError("Header and tensor must be on the same device as communicator")

        header_work = dist.isend(header, dst=dst, tag=tag_header)
        tensor_work = dist.isend(tensor, dst=dst, tag=tag_tensor)
        self.outstanding_operations.extend([header_work, tensor_work])
        self.logger.debug(
            "Started async send of header+tensor to rank %s, "
            "header shape: %s, tensor shape: %s",
            dst, header.shape, tensor.shape,
        )
        return header_work, tensor_work

    def recv_header_and_tensor(self, src: int, tag_header: int, tag_tensor: int, header_len: int,
                               dtype: torch.dtype = torch.bfloat16) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Receive a header and tensor pair (blocking).

        The header is an int64 vector laid out as ``[chunk_idx, *shape]``;
        the shape part determines the size of the payload buffer.

        Args:
            src: Source rank
            tag_header: Tag for header
            tag_tensor: Tag for tensor
            header_len: Length of header tensor to receive
            dtype: Dtype of the payload tensor. Defaults to ``torch.bfloat16``,
                which was previously the only supported value.

        Returns:
            Tuple of (header, tensor)
        """
        with CommunicationTimer(f"recv_header_and_tensor from rank {src}", self.logger):
            # First receive the header to get tensor shape (length can vary)
            header = torch.empty(header_len, dtype=torch.int64, device=self.device)
            dist.recv(header, src=src, tag=tag_header)

            # Parse header to get tensor shape
            chunk_idx, shape = self._parse_header(header)

            # Receive the tensor
            tensor = torch.empty(shape, dtype=dtype, device=self.device)
            dist.recv(tensor, src=src, tag=tag_tensor)

        # Rebuild the header contents from the already-parsed values instead
        # of calling header.tolist() a second time just for the log line.
        self.logger.debug(
            "Received header+tensor from rank %s, header: %s, tensor shape: %s",
            src, [chunk_idx, *shape], tensor.shape,
        )
        return header, tensor

    def send_latent_data_async(self, chunk_idx: int, latents: torch.Tensor,
                               original_latents: torch.Tensor, patched_x_shape: torch.Tensor,
                               current_start: torch.Tensor, current_end: torch.Tensor,
                               current_step: int) -> List[Any]:
        """
        Asynchronously send all latent data components.

        The destination is the rank returned by ``get_next_rank``. Six
        messages are issued: latent header + payload, original-latent
        header + payload, ``patched_x_shape`` and a combined
        ``[current_start, current_end, current_step]`` vector.

        Args:
            chunk_idx: Chunk index
            latents: Latent tensor
            original_latents: Original latent tensor
            patched_x_shape: Patched x shape tensor
            current_start: Current start indices
            current_end: Current end indices
            current_step: Current step

        Returns:
            List of work objects for all send operations
        """
        dst = get_next_rank(self.rank, self.world_size)
        work_objects = []

        # Create headers
        latent_header = self._create_header(chunk_idx, latents.shape)
        origin_header = self._create_header(chunk_idx, original_latents.shape)

        # Create start/end/step tensor
        start_end_step = torch.cat([
            current_start,
            current_end,
            torch.tensor([current_step], dtype=torch.int64, device=self.device),
        ], dim=0)

        # Send all components asynchronously
        work_objects.extend(self.send_header_and_tensor_async(
            latent_header, latents, dst, CommunicationTags.LATENT_HDR, CommunicationTags.LATENT_PAY
        ))
        work_objects.extend(self.send_header_and_tensor_async(
            origin_header, original_latents, dst,
            CommunicationTags.LATENT_ORIGIN_HDR, CommunicationTags.LATENT_ORIGIN_PAY
        ))
        work_objects.append(self.send_tensor_async(
            patched_x_shape, dst, CommunicationTags.PATCHED_X_SHAPE
        ))
        work_objects.append(self.send_tensor_async(
            start_end_step, dst, CommunicationTags.START_END_STEP
        ))

        self.logger.debug(
            "Started async send of latent data to rank %s, chunk_idx: %s", dst, chunk_idx
        )
        return work_objects

    def recv_latent_data_async(self, num_steps: int, buffer_manager) -> Tuple[int, torch.Tensor, torch.Tensor,
                                                                              torch.Tensor, torch.Tensor, int, torch.Tensor]:
        """
        Receive all latent data components from the previous rank.

        Despite the name, every receive here uses the blocking ``dist.recv``.
        The source is the rank returned by ``get_prev_rank``. Payload buffers
        are requested from ``buffer_manager`` as ``torch.bfloat16``.

        Args:
            num_steps: Number of denoising steps
            buffer_manager: Buffer manager for tensor allocation; must provide
                ``get_buffer(shape, dtype, pool_name)`` and
                ``return_buffer(tensor, pool_name)``

        Returns:
            Tuple of (chunk_idx, latents, original_latents, current_start, current_end, current_step, patched_x_shape)
        """
        # NOTE(review): the returned tensors are buffers obtained from
        # buffer_manager and are not handed back here; presumably the caller
        # returns them when done - confirm against the buffer manager's contract.
        src = get_prev_rank(self.rank, self.world_size)

        with CommunicationTimer(f"recv_latent_data_async from rank {src}", self.logger):
            # Receive latent header (length 4): [i, bsz, slen, cch]
            latent_header = buffer_manager.get_buffer((4,), torch.int64, "misc")
            dist.recv(latent_header, src=src, tag=CommunicationTags.LATENT_HDR)
            chunk_idx, latent_shape = self._parse_header(latent_header)
            # header no longer needed
            buffer_manager.return_buffer(latent_header, "misc")

            # Allocate or reuse buffer for latents: shape (bsz, slen, cch)
            latents = buffer_manager.get_buffer(tuple(latent_shape), torch.bfloat16, "latent")
            dist.recv(latents, src=src, tag=CommunicationTags.LATENT_PAY)

            # Receive original latent header (length 6): [i, bsz, cch, tlen, hh, ww]
            origin_header = buffer_manager.get_buffer((6,), torch.int64, "misc")
            dist.recv(origin_header, src=src, tag=CommunicationTags.LATENT_ORIGIN_HDR)
            _, origin_shape = self._parse_header(origin_header)
            # header no longer needed
            buffer_manager.return_buffer(origin_header, "misc")

            # Allocate or reuse buffer for original latents: shape (bsz, cch, tlen, hh, ww)
            original_latents = buffer_manager.get_buffer(tuple(origin_shape), torch.bfloat16, "origin")
            dist.recv(original_latents, src=src, tag=CommunicationTags.LATENT_ORIGIN_PAY)

            # Receive patched_x_shape (length 5, int64)
            patched_x_shape = buffer_manager.get_buffer((5,), torch.int64, "misc")
            dist.recv(patched_x_shape, src=src, tag=CommunicationTags.PATCHED_X_SHAPE)

            # Receive start_end_step (length 2*num_steps+1, int64)
            start_end_step = buffer_manager.get_buffer((2 * num_steps + 1,), torch.int64, "misc")
            dist.recv(start_end_step, src=src, tag=CommunicationTags.START_END_STEP)

            # Parse start/end/step into dedicated misc buffers, then release the combined vector
            current_start = buffer_manager.get_buffer((num_steps,), torch.int64, "misc")
            current_end = buffer_manager.get_buffer((num_steps,), torch.int64, "misc")
            current_start.copy_(start_end_step[:num_steps])
            current_end.copy_(start_end_step[num_steps:-1])
            current_step = int(start_end_step[-1].item())

            # Release the temporary combined buffer
            buffer_manager.return_buffer(start_end_step, "misc")

        self.logger.debug("Received latent data from rank %s, chunk_idx: %s", src, chunk_idx)
        return chunk_idx, latents, original_latents, current_start, current_end, current_step, patched_x_shape

    def send_prompt_async(self, prompt: str, device: torch.device) -> List[Any]:
        """
        Asynchronously send a text prompt to the next rank.

        The prompt is UTF-8 encoded and sent as two messages: an int64 tensor
        holding the byte count, followed by a uint8 tensor holding the bytes.

        Args:
            prompt: Text to send
            device: Device on which the length and byte tensors are created

        Returns:
            List of work objects: [length_work, content_work]
        """
        # NOTE(review): unlike the other send helpers, these work objects are
        # not added to self.outstanding_operations; the caller has to wait on
        # the returned handles itself. Left unchanged - confirm this is intended.
        dst = get_next_rank(self.rank, self.world_size)

        # Encode to bytes and build the uint8 tensor directly on the target device
        encoded = prompt.encode("utf-8")
        data = torch.tensor(list(encoded), dtype=torch.uint8, device=device)

        # Send length first so the receiver can size its buffer
        length = torch.tensor([len(data)], dtype=torch.int64, device=data.device)
        length_work = dist.isend(length, dst=dst, tag=CommunicationTags.UPDATED_PROMPT_LENGTH)

        # Then send the content
        content_work = dist.isend(data, dst=dst, tag=CommunicationTags.UPDATED_PROMPT)
        return [length_work, content_work]

    def recv_prompt_async(self) -> str:
        """
        Receive a text prompt from the previous rank.

        Despite the name, both receives use the blocking ``dist.recv``. The
        byte count is received first, then that many uint8 values, which are
        decoded as UTF-8.

        Returns:
            The decoded prompt string
        """
        src = get_prev_rank(self.rank, self.world_size)

        # Receive length first
        length = torch.empty(1, dtype=torch.int64, device=self.device)
        dist.recv(length, src=src, tag=CommunicationTags.UPDATED_PROMPT_LENGTH)

        # Then receive the content
        prompt = torch.empty(int(length.item()), dtype=torch.uint8, device=self.device)
        dist.recv(prompt, src=src, tag=CommunicationTags.UPDATED_PROMPT)
        return bytes(prompt.cpu().tolist()).decode("utf-8")

    def broadcast_tensor(self, tensor: torch.Tensor, src: int) -> None:
        """
        Broadcast a tensor from source to all ranks.

        Args:
            tensor: Tensor to broadcast (passed to ``dist.broadcast`` as-is)
            src: Source rank
        """
        with CommunicationTimer(f"broadcast_tensor from rank {src}", self.logger):
            dist.broadcast(tensor, src=src)
        self.logger.debug("Broadcasted tensor from rank %s, shape: %s", src, tensor.shape)

    def all_gather_tensors(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """
        Gather tensors from all ranks.

        Args:
            tensor: Local tensor to gather

        Returns:
            List of ``world_size`` tensors, each allocated with the shape and
            dtype of ``tensor`` and filled by ``dist.all_gather``
        """
        with CommunicationTimer("all_gather_tensors", self.logger):
            gather_list = [torch.zeros_like(tensor) for _ in range(self.world_size)]
            dist.all_gather(gather_list, tensor)
        self.logger.debug("Gathered tensors from all ranks, local shape: %s", tensor.shape)
        return gather_list

    def wait_for_outstanding(self, max_outstanding: Optional[int] = None) -> None:
        """
        Wait for outstanding operations to complete.

        Entries are retired oldest-first until fewer than ``max_outstanding``
        remain. Each entry may be a single work object or a list/tuple of
        work objects.

        Args:
            max_outstanding: Upper bound (exclusive) on the number of entries
                left in flight. ``None`` uses ``config.max_outstanding``;
                ``0`` waits for everything.

        Raises:
            Exception: Whatever ``work.wait()`` raises, after logging it.
        """
        # Only fall back to the configured limit when the argument is truly
        # absent; an explicit 0 is a valid request to drain the whole queue.
        if max_outstanding is None:
            max_outstanding = self.config.max_outstanding

        while self.outstanding_operations and len(self.outstanding_operations) >= max_outstanding:
            # Wait for the oldest operation
            oldest = self.outstanding_operations.pop(0)

            # Handle both single work objects and lists of work objects
            works = oldest if isinstance(oldest, (list, tuple)) else (oldest,)
            for work in works:
                try:
                    work.wait()
                except Exception as e:
                    self.logger.error("Error waiting for outstanding operation: %s", e)
                    raise

            self.logger.debug("Outstanding operations: %s", len(self.outstanding_operations))

    def barrier(self) -> None:
        """Synchronize all ranks."""
        with CommunicationTimer("barrier", self.logger):
            dist.barrier()

    def _create_header(self, chunk_idx: int, shape: Tuple[int, ...]) -> torch.Tensor:
        """Create an int64 header tensor ``[chunk_idx, *shape]`` on the communicator's device."""
        header_data = [chunk_idx, *shape]
        return torch.tensor(header_data, dtype=torch.int64, device=self.device)

    def _parse_header(self, header: torch.Tensor) -> Tuple[int, Tuple[int, ...]]:
        """Parse a header tensor into ``(chunk_idx, shape)``; inverse of ``_create_header``."""
        first, *dims = header.tolist()
        return int(first), tuple(int(d) for d in dims)

    def get_statistics(self) -> dict:
        """Get communication statistics as a plain dict."""
        return {
            "rank": self.rank,
            "world_size": self.world_size,
            "outstanding_operations": len(self.outstanding_operations),
            "max_outstanding": self.config.max_outstanding,
            "device": str(self.device),
        }

    def print_statistics(self) -> None:
        """Log communication statistics at INFO level, one line per entry."""
        stats = self.get_statistics()
        self.logger.info("Distributed Communicator Statistics:")
        for key, value in stats.items():
            self.logger.info("  %s: %s", key, value)