"""
Sequence Parallelism utilities for handling very long sequences.
Sequence parallelism splits sequences across multiple GPUs, allowing training
on sequences that don't fit in a single GPU's memory.
"""

import copy
import logging
from typing import Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def split_sequence_across_gpus(
    sequence: torch.Tensor,
    num_gpus: int,
    dim: int = 1,
) -> list[torch.Tensor]:
    """
    Split a sequence tensor across multiple GPUs.

    Args:
        sequence: Input sequence tensor (B, L, ...)
        num_gpus: Number of GPUs to split across
        dim: Dimension to split along (usually the sequence length)

    Returns:
        List of tensors, one per GPU
    """
    if num_gpus == 1:
        return [sequence]

    seq_len = sequence.shape[dim]
    chunk_size = seq_len // num_gpus

    chunks = []
    for i in range(num_gpus):
        start_idx = i * chunk_size
        # The last chunk absorbs the remainder when seq_len is not divisible by num_gpus.
        end_idx = (i + 1) * chunk_size if i < num_gpus - 1 else seq_len
        chunk = torch.narrow(sequence, dim, start_idx, end_idx - start_idx)
        chunks.append(chunk)
    return chunks
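
# A quick illustration of the split behaviour (the shapes below are illustrative,
# not taken from the original module): for a (2, 10, 8) input split across 3 GPUs
# along dim=1, the chunks have sequence lengths 3, 3, and 4, since the last chunk
# absorbs the remainder.
#
#   x = torch.randn(2, 10, 8)
#   chunks = split_sequence_across_gpus(x, num_gpus=3)
#   [c.shape[1] for c in chunks]  # -> [3, 3, 4]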


def gather_sequence_from_gpus(
    chunks: list[torch.Tensor],
    dim: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Gather sequence chunks from multiple GPUs.

    Args:
        chunks: List of sequence chunks, one per GPU
        dim: Dimension to concatenate along
        device: Target device for the gathered tensor

    Returns:
        Concatenated sequence tensor
    """
    if len(chunks) == 1:
        # Still honour an explicit target device for a single chunk.
        return chunks[0] if device is None else chunks[0].to(device)

    # Move all chunks to the same device if needed
    if device is not None:
        chunks = [chunk.to(device) for chunk in chunks]

    # Concatenate along the sequence dimension
    gathered = torch.cat(chunks, dim=dim)
    return gathered
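
# Round-trip property worth keeping in mind (an illustrative check, not original
# API documentation): gathering the chunks produced by split_sequence_across_gpus,
# in order and along the same dim, reconstructs the original tensor.
#
#   chunks = split_sequence_across_gpus(x, num_gpus=4)
#   torch.equal(gather_sequence_from_gpus(chunks), x)  # -> True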


class SequenceParallelWrapper(nn.Module):
    """
    Wrapper for sequence parallelism.

    Splits input sequences across GPUs and gathers the outputs.
    """

    def __init__(
        self,
        model: nn.Module,
        num_gpus: int = 1,
        sequence_dim: int = 1,
    ):
        """
        Initialize the sequence parallel wrapper.

        Args:
            model: Model to wrap
            num_gpus: Number of GPUs to use
            sequence_dim: Dimension of the sequence length
        """
        super().__init__()
        self.model = model
        self.num_gpus = num_gpus
        self.sequence_dim = sequence_dim

        # Replicate the model across GPUs. Note that Module.to() moves a module
        # in place and returns the same object, so each replica must be a deep
        # copy; otherwise every list entry would point at one module on the last device.
        if num_gpus > 1:
            self.models = nn.ModuleList(
                [copy.deepcopy(model).to(torch.device(f"cuda:{i}")) for i in range(num_gpus)]
            )
        else:
            self.models = nn.ModuleList([model])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with sequence parallelism.

        Args:
            x: Input tensor (B, L, ...)

        Returns:
            Output tensor
        """
        if self.num_gpus == 1:
            return self.model(x)

        # Split the sequence across GPUs
        chunks = split_sequence_across_gpus(x, self.num_gpus, dim=self.sequence_dim)

        # Process each chunk on its own GPU
        outputs = []
        for i, (chunk, model) in enumerate(zip(chunks, self.models)):
            chunk = chunk.to(torch.device(f"cuda:{i}"))
            output = model(chunk)
            outputs.append(output)

        # Gather the outputs back onto the input's device
        result = gather_sequence_from_gpus(outputs, dim=self.sequence_dim, device=x.device)
        return result


def enable_sequence_parallelism(
    model: nn.Module,
    num_gpus: int = 1,
    sequence_dim: int = 1,
) -> nn.Module:
    """
    Enable sequence parallelism for a model.

    Args:
        model: Model to enable sequence parallelism for
        num_gpus: Number of GPUs to use
        sequence_dim: Dimension of the sequence length

    Returns:
        Model wrapped with sequence parallelism
    """
    if num_gpus <= 1:
        logger.warning("Sequence parallelism requires more than one GPU; returning the model unchanged")
        return model

    logger.info(f"Enabling sequence parallelism: {num_gpus} GPUs, sequence_dim={sequence_dim}")
    return SequenceParallelWrapper(model, num_gpus, sequence_dim)
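

# Minimal smoke-test sketch (the toy model and shapes below are illustrative
# assumptions, not part of the module above). With num_gpus=1 the wrapper simply
# delegates to the model, so this runs on CPU without any GPUs; on a multi-GPU
# machine, pass num_gpus=torch.cuda.device_count() instead.
if __name__ == "__main__":
    toy_model = nn.Linear(16, 16)
    wrapped = enable_sequence_parallelism(toy_model, num_gpus=1, sequence_dim=1)

    x = torch.randn(2, 128, 16)  # (batch, sequence length, features)
    y = wrapped(x)
    print(f"input {tuple(x.shape)} -> output {tuple(y.shape)}")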