"""
Sequence Parallelism utilities for handling very long sequences.

Sequence parallelism splits sequences across multiple GPUs, allowing training
on sequences that don't fit in a single GPU's memory.
"""

import copy
import logging
from typing import Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def split_sequence_across_gpus(
    sequence: torch.Tensor,
    num_gpus: int,
    dim: int = 1,
) -> list[torch.Tensor]:
    """
    Split sequence tensor across multiple GPUs.

    Args:
        sequence: Input sequence tensor (B, L, ...)
        num_gpus: Number of GPUs to split across
        dim: Dimension to split along (usually sequence length)

    Returns:
        List of tensors, one per GPU
    """
    if num_gpus == 1:
        return [sequence]

    seq_len = sequence.shape[dim]
    chunk_size = seq_len // num_gpus

    chunks = []
    for i in range(num_gpus):
        start_idx = i * chunk_size
        end_idx = (i + 1) * chunk_size if i < num_gpus - 1 else seq_len
        chunk = torch.narrow(sequence, dim, start_idx, end_idx - start_idx)
        chunks.append(chunk)

    return chunks
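
# Illustrative sketch (not part of the original module): splitting a
# (2, 10, 8) tensor into 3 chunks along dim=1 gives chunk lengths 3, 3 and 4,
# because the final chunk absorbs the remainder:
#
#   x = torch.randn(2, 10, 8)
#   parts = split_sequence_across_gpus(x, num_gpus=3, dim=1)
#   [p.shape[1] for p in parts]  # -> [3, 3, 4]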


def gather_sequence_from_gpus(
    chunks: list[torch.Tensor],
    dim: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Gather sequence chunks from multiple GPUs.

    Args:
        chunks: List of sequence chunks, one per GPU
        dim: Dimension to concatenate along
        device: Target device for gathered tensor

    Returns:
        Concatenated sequence tensor
    """
    if len(chunks) == 1:
        return chunks[0]

    if device is not None:
        chunks = [chunk.to(device) for chunk in chunks]

    gathered = torch.cat(chunks, dim=dim)

    return gathered
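
# Illustrative sketch (not part of the original module): gathering is the
# inverse of splitting, so a split followed by a gather on the same dimension
# reconstructs the original tensor:
#
#   x = torch.randn(2, 10, 8)
#   parts = split_sequence_across_gpus(x, num_gpus=4, dim=1)
#   assert torch.equal(gather_sequence_from_gpus(parts, dim=1), x)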


class SequenceParallelWrapper(nn.Module):
    """
    Wrapper for sequence parallelism.

    Splits input sequences across GPUs and gathers outputs.
    """

    def __init__(
        self,
        model: nn.Module,
        num_gpus: int = 1,
        sequence_dim: int = 1,
    ):
        """
        Initialize sequence parallel wrapper.

        Args:
            model: Model to wrap
            num_gpus: Number of GPUs to use
            sequence_dim: Dimension of sequence length
        """
        super().__init__()
        self.model = model
        self.num_gpus = num_gpus
        self.sequence_dim = sequence_dim

        if num_gpus > 1:
            # Each GPU gets its own replica of the model. ``Module.to()`` moves a
            # module in place, so without ``deepcopy`` every list entry would
            # reference the same module, left on the last GPU.
            self.models = nn.ModuleList(
                [copy.deepcopy(model).to(torch.device(f"cuda:{i}")) for i in range(num_gpus)]
            )
        else:
            self.models = nn.ModuleList([model])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with sequence parallelism.

        Args:
            x: Input tensor (B, L, ...)

        Returns:
            Output tensor
        """
        if self.num_gpus == 1:
            return self.model(x)

        chunks = split_sequence_across_gpus(x, self.num_gpus, dim=self.sequence_dim)

        outputs = []
        for i, (chunk, model) in enumerate(zip(chunks, self.models)):
            # Run each chunk on the GPU that holds the corresponding model replica.
            chunk = chunk.to(torch.device(f"cuda:{i}"))
            output = model(chunk)
            outputs.append(output)

        result = gather_sequence_from_gpus(outputs, dim=self.sequence_dim, device=x.device)

        return result
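
# Illustrative sketch (assumes at least two CUDA devices; the wrapped module is
# a hypothetical example): each GPU runs its own replica of the model on one
# chunk of the sequence, so the concatenated output is only exact for models
# that treat sequence positions independently along ``sequence_dim``.
#
#   model = nn.Linear(64, 64)
#   sp_model = SequenceParallelWrapper(model, num_gpus=2, sequence_dim=1)
#   y = sp_model(torch.randn(4, 8192, 64))  # each GPU sees a (4, 4096, 64) chunk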


def enable_sequence_parallelism(
    model: nn.Module,
    num_gpus: int = 1,
    sequence_dim: int = 1,
) -> nn.Module:
    """
    Enable sequence parallelism for a model.

    Args:
        model: Model to enable sequence parallelism for
        num_gpus: Number of GPUs to use
        sequence_dim: Dimension of sequence length

    Returns:
        Model wrapped with sequence parallelism
    """
    if num_gpus <= 1:
        logger.warning("Sequence parallelism requires more than one GPU; returning model unwrapped")
        return model

    logger.info(f"Enabling sequence parallelism: {num_gpus} GPUs, sequence_dim={sequence_dim}")

    return SequenceParallelWrapper(model, num_gpus, sequence_dim)
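
# Illustrative usage sketch (hypothetical model and GPU count; assumes CUDA is
# available):
#
#   model = nn.Linear(128, 128)
#   sp_model = enable_sequence_parallelism(model, num_gpus=2, sequence_dim=1)
#   out = sp_model(torch.randn(2, 16384, 128))
#
# With num_gpus <= 1 the model is returned unwrapped and a warning is logged.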