# ylff/utils/sequence_parallel.py
"""
Sequence Parallelism utilities for handling very long sequences.
Sequence parallelism splits sequences across multiple GPUs, allowing training
on sequences that don't fit in a single GPU's memory.
"""
import copy
import logging
from typing import Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

def split_sequence_across_gpus(
sequence: torch.Tensor,
num_gpus: int,
dim: int = 1,
) -> list[torch.Tensor]:
"""
Split sequence tensor across multiple GPUs.
Args:
sequence: Input sequence tensor (B, L, ...)
num_gpus: Number of GPUs to split across
dim: Dimension to split along (usually sequence length)
Returns:
List of tensors, one per GPU
"""
if num_gpus == 1:
return [sequence]
seq_len = sequence.shape[dim]
chunk_size = seq_len // num_gpus
    chunks = []
    for i in range(num_gpus):
        start_idx = i * chunk_size
        # The last chunk absorbs the remainder when seq_len is not evenly
        # divisible by num_gpus.
        end_idx = (i + 1) * chunk_size if i < num_gpus - 1 else seq_len
        chunk = torch.narrow(sequence, dim, start_idx, end_idx - start_idx)
        chunks.append(chunk)
    return chunks
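
# A quick sketch of the chunking behavior (shapes here are illustrative):
# splitting a (2, 10, 8) tensor across 3 GPUs along dim=1 yields chunks of
# length 3, 3, and 4, since the last chunk absorbs the remainder.
#
#   x = torch.randn(2, 10, 8)
#   chunks = split_sequence_across_gpus(x, num_gpus=3, dim=1)
#   [c.shape[1] for c in chunks]  # -> [3, 3, 4]
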
def gather_sequence_from_gpus(
chunks: list[torch.Tensor],
dim: int = 1,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Gather sequence chunks from multiple GPUs.
Args:
chunks: List of sequence chunks, one per GPU
dim: Dimension to concatenate along
device: Target device for gathered tensor
Returns:
Concatenated sequence tensor
"""
if len(chunks) == 1:
return chunks[0]
# Move all chunks to same device if needed
if device is not None:
chunks = [chunk.to(device) for chunk in chunks]
# Concatenate along sequence dimension
gathered = torch.cat(chunks, dim=dim)
return gathered
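
# Round-trip sketch: the chunks are consecutive slices in order, so gathering
# them reconstructs the original tensor exactly (assuming all chunks end up
# on the same device):
#
#   chunks = split_sequence_across_gpus(x, num_gpus=3, dim=1)
#   y = gather_sequence_from_gpus(chunks, dim=1, device=x.device)
#   torch.equal(x, y)  # -> True
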
class SequenceParallelWrapper(nn.Module):
    """
    Wrapper for sequence parallelism.

    Splits input sequences across GPUs, runs a replica of the wrapped model
    on each chunk, and gathers the outputs. Chunks are processed
    independently, so results match single-GPU execution only for models
    whose output at each position does not depend on other chunks
    (per-token layers, not full-sequence attention). The replicas do not
    synchronize gradients, so the wrapper is primarily suited to inference.
    """
def __init__(
self,
model: nn.Module,
num_gpus: int = 1,
sequence_dim: int = 1,
):
"""
Initialize sequence parallel wrapper.
Args:
model: Model to wrap
num_gpus: Number of GPUs to use
sequence_dim: Dimension of sequence length
"""
        super().__init__()
        self.model = model
        self.num_gpus = num_gpus
        self.sequence_dim = sequence_dim
        # Replicate the model once per GPU. copy.deepcopy is required here:
        # nn.Module.to() moves a module in place and returns the same object,
        # so without the copy every list entry would alias a single model
        # living on the last GPU.
        if num_gpus > 1:
            self.models = nn.ModuleList(
                [
                    copy.deepcopy(model).to(torch.device(f"cuda:{i}"))
                    for i in range(num_gpus)
                ]
            )
        else:
            self.models = nn.ModuleList([model])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with sequence parallelism.

        Args:
            x: Input tensor (B, L, ...).

        Returns:
            Output tensor on the same device as x.
        """
if self.num_gpus == 1:
return self.model(x)
# Split sequence across GPUs
chunks = split_sequence_across_gpus(x, self.num_gpus, dim=self.sequence_dim)
# Process each chunk on its GPU
outputs = []
for i, (chunk, model) in enumerate(zip(chunks, self.models)):
chunk = chunk.to(torch.device(f"cuda:{i}"))
output = model(chunk)
outputs.append(output)
# Gather outputs
result = gather_sequence_from_gpus(outputs, dim=self.sequence_dim, device=x.device)
return result
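
# Minimal usage sketch (assumes at least 2 CUDA devices; nn.Linear is just an
# illustrative per-token model for which chunk-independent processing is
# exact):
#
#   base = nn.Linear(64, 64)
#   sp = SequenceParallelWrapper(base, num_gpus=2, sequence_dim=1)
#   out = sp(torch.randn(4, 1024, 64))  # each half of L=1024 runs on its own GPU
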
def enable_sequence_parallelism(
model: nn.Module,
num_gpus: int = 1,
sequence_dim: int = 1,
) -> nn.Module:
"""
Enable sequence parallelism for a model.
Args:
model: Model to enable sequence parallelism for
num_gpus: Number of GPUs to use
sequence_dim: Dimension of sequence length
Returns:
Model wrapped with sequence parallelism
"""
    if num_gpus <= 1:
        logger.warning(
            "Sequence parallelism requested with num_gpus <= 1; "
            "returning the model unchanged"
        )
        return model
    logger.info(
        f"Enabling sequence parallelism: {num_gpus} GPUs, sequence_dim={sequence_dim}"
    )
    return SequenceParallelWrapper(model, num_gpus, sequence_dim)
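
# Typical entry point (a sketch; ``model`` stands for any nn.Module):
#
#   model = enable_sequence_parallelism(
#       model,
#       num_gpus=torch.cuda.device_count(),
#       sequence_dim=1,
#   )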