"""
Sequence Parallelism utilities for handling very long sequences.

Sequence parallelism splits sequences across multiple GPUs, allowing training
on sequences that don't fit in a single GPU's memory.
"""

import copy
import logging
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def split_sequence_across_gpus(
    sequence: torch.Tensor,
    num_gpus: int,
    dim: int = 1,
) -> list[torch.Tensor]:
    """
    Split sequence tensor across multiple GPUs.

    The tensor is cut into ``num_gpus`` contiguous views along ``dim``. Every
    chunk except the last has ``seq_len // num_gpus`` elements along ``dim``;
    the last chunk additionally receives the remainder. No data is copied and
    no device transfer happens here.

    Args:
        sequence: Input sequence tensor (B, L, ...)
        num_gpus: Number of GPUs to split across (must be >= 1)
        dim: Dimension to split along (usually sequence length)

    Returns:
        List of tensors, one per GPU

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")
    if num_gpus == 1:
        return [sequence]

    seq_len = sequence.shape[dim]
    chunk_size = seq_len // num_gpus
    last_index = num_gpus - 1

    chunks = []
    for i in range(num_gpus):
        start_idx = i * chunk_size
        # The final chunk absorbs whatever is left after integer division.
        length = chunk_size if i < last_index else seq_len - start_idx
        chunks.append(torch.narrow(sequence, dim, start_idx, length))
    return chunks


def gather_sequence_from_gpus(
    chunks: list[torch.Tensor],
    dim: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Gather sequence chunks from multiple GPUs.

    Args:
        chunks: List of sequence chunks, one per GPU
        dim: Dimension to concatenate along
        device: Target device for gathered tensor. When omitted, the device of
            the first chunk is used so chunks living on different devices can
            still be concatenated.

    Returns:
        Concatenated sequence tensor
    """
    # torch.cat requires all inputs on one device; fall back to the first
    # chunk's device when the caller did not choose one.
    if device is None and chunks:
        device = chunks[0].device

    if device is not None:
        chunks = [chunk.to(device) for chunk in chunks]

    # A single chunk needs no concatenation, but it must still honour `device`
    # (handled by the move above).
    if len(chunks) == 1:
        return chunks[0]

    return torch.cat(chunks, dim=dim)


class SequenceParallelWrapper(nn.Module):
    """
    Wrapper for sequence parallelism.

    Splits input sequences across GPUs and gathers outputs.

    NOTE: each chunk is run through its replica independently, so the result
    equals an unsplit forward pass only when the wrapped model does not mix
    information along the sequence dimension.

    NOTE: replicas on devices other than the first are deep copies of the
    wrapped model. Nothing in this class synchronises their parameters or
    gradients.
    """

    def __init__(
        self,
        model: nn.Module,
        num_gpus: int = 1,
        sequence_dim: int = 1,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
    ):
        """
        Initialize sequence parallel wrapper.

        Args:
            model: Model to wrap
            num_gpus: Number of GPUs to use (must be >= 1)
            sequence_dim: Dimension of sequence length
            devices: Optional explicit devices, one per replica. Defaults to
                ``cuda:0 .. cuda:{num_gpus - 1}``. Ignored when
                ``num_gpus == 1``, in which case the model is used as-is.

        Raises:
            ValueError: If ``num_gpus`` is smaller than 1 or ``devices`` does
                not contain exactly ``num_gpus`` entries.
        """
        super().__init__()
        if num_gpus < 1:
            raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")

        self.model = model
        self.num_gpus = num_gpus
        self.sequence_dim = sequence_dim

        if num_gpus == 1:
            self.devices = []
            self.models = nn.ModuleList([model])
            return

        if devices is None:
            resolved = [torch.device(f"cuda:{i}") for i in range(num_gpus)]
        else:
            resolved = [torch.device(d) for d in devices]
            if len(resolved) != num_gpus:
                raise ValueError(
                    f"expected {num_gpus} devices, got {len(resolved)}"
                )
        self.devices = resolved

        # nn.Module.to() moves the module in place and returns the same
        # object, so calling it repeatedly on `model` would yield one shared
        # module sitting on the last device. Keep the original on the first
        # device and deep-copy it for every other device.
        replicas = [model.to(resolved[0])]
        replicas.extend(copy.deepcopy(model).to(d) for d in resolved[1:])
        self.models = nn.ModuleList(replicas)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with sequence parallelism.

        Args:
            x: Input tensor (B, L, ...)

        Returns:
            Output tensor, gathered on the device of ``x``
        """
        if self.num_gpus == 1:
            return self.model(x)

        # Split sequence across GPUs
        chunks = split_sequence_across_gpus(x, self.num_gpus, dim=self.sequence_dim)

        # Process each chunk on the device that hosts its replica
        outputs = [
            replica(chunk.to(device))
            for chunk, replica, device in zip(chunks, self.models, self.devices)
        ]

        # Gather outputs back onto the input's device
        return gather_sequence_from_gpus(
            outputs, dim=self.sequence_dim, device=x.device
        )


def enable_sequence_parallelism(
    model: nn.Module,
    num_gpus: int = 1,
    sequence_dim: int = 1,
) -> nn.Module:
    """
    Enable sequence parallelism for a model.

    Args:
        model: Model to enable sequence parallelism for
        num_gpus: Number of GPUs to use
        sequence_dim: Dimension of sequence length

    Returns:
        Model wrapped with sequence parallelism, or the unmodified ``model``
        (after logging a warning) when ``num_gpus <= 1``
    """
    if num_gpus <= 1:
        logger.warning("Sequence parallelism requires multiple GPUs")
        return model

    logger.info(
        "Enabling sequence parallelism: %s GPUs, sequence_dim=%s",
        num_gpus,
        sequence_dim,
    )
    return SequenceParallelWrapper(model, num_gpus, sequence_dim)