"""
Sequence Parallelism utilities for handling very long sequences.

Sequence parallelism splits sequences across multiple GPUs, allowing training
on sequences that don't fit in a single GPU's memory.
"""

import copy
import logging
from typing import Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def split_sequence_across_gpus(
    sequence: torch.Tensor,
    num_gpus: int,
    dim: int = 1,
) -> list[torch.Tensor]:
    """
    Split sequence tensor across multiple GPUs.

    Args:
        sequence: Input sequence tensor (B, L, ...)
        num_gpus: Number of GPUs to split across
        dim: Dimension to split along (usually sequence length)

    Returns:
        List of tensors, one per GPU
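
    Example (illustrative)::

        >>> x = torch.randn(2, 10, 8)
        >>> parts = split_sequence_across_gpus(x, num_gpus=3)
        >>> [p.shape[1] for p in parts]
        [3, 3, 4]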
    """
    if num_gpus == 1:
        return [sequence]

    seq_len = sequence.shape[dim]
    if seq_len < num_gpus:
        raise ValueError(
            f"Sequence length {seq_len} is shorter than num_gpus={num_gpus}"
        )
    chunk_size = seq_len // num_gpus

    chunks = []
    for i in range(num_gpus):
        start_idx = i * chunk_size
        # The last chunk absorbs the remainder when seq_len is not evenly
        # divisible by num_gpus.
        end_idx = (i + 1) * chunk_size if i < num_gpus - 1 else seq_len
        chunk = torch.narrow(sequence, dim, start_idx, end_idx - start_idx)
        chunks.append(chunk)

    return chunks


def gather_sequence_from_gpus(
    chunks: list[torch.Tensor],
    dim: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Gather sequence chunks from multiple GPUs.

    Args:
        chunks: List of sequence chunks, one per GPU
        dim: Dimension to concatenate along
        device: Target device for gathered tensor

    Returns:
        Concatenated sequence tensor
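
    Example (illustrative)::

        >>> parts = [torch.zeros(2, 3, 8), torch.zeros(2, 4, 8)]
        >>> gather_sequence_from_gpus(parts).shape
        torch.Size([2, 7, 8])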
    """
    if len(chunks) == 1:
        return chunks[0]

    # Move all chunks to the target device if one was given
    if device is not None:
        chunks = [chunk.to(device) for chunk in chunks]

    # Concatenate along sequence dimension
    gathered = torch.cat(chunks, dim=dim)

    return gathered


class SequenceParallelWrapper(nn.Module):
    """
    Wrapper for sequence parallelism.

    Splits input sequences across GPUs and gathers outputs.
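
    Example (sketch; the multi-GPU path assumes at least ``num_gpus``
    visible CUDA devices)::

        model = nn.Linear(8, 8)
        wrapped = SequenceParallelWrapper(model, num_gpus=2)
        y = wrapped(torch.randn(2, 10, 8))  # (B, L, D) in and out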
    """

    def __init__(
        self,
        model: nn.Module,
        num_gpus: int = 1,
        sequence_dim: int = 1,
    ):
        """
        Initialize sequence parallel wrapper.

        Args:
            model: Model to wrap
            num_gpus: Number of GPUs to use
            sequence_dim: Dimension of sequence length
        """
        super().__init__()
        self.num_gpus = num_gpus
        self.sequence_dim = sequence_dim

        # Replicate the model, one copy per GPU. nn.Module.to() moves a
        # module in place and returns self, so each replica must be a deep
        # copy; otherwise every entry would alias the same module on the
        # last GPU.
        if num_gpus > 1:
            self.models = nn.ModuleList(
                [copy.deepcopy(model).to(torch.device(f"cuda:{i}")) for i in range(num_gpus)]
            )
        else:
            self.models = nn.ModuleList([model])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with sequence parallelism.

        Args:
            x: Input tensor (B, L, ...)

        Returns:
            Output tensor
        """
        if self.num_gpus == 1:
            return self.models[0](x)

        # Split sequence across GPUs
        chunks = split_sequence_across_gpus(x, self.num_gpus, dim=self.sequence_dim)

        # Process each chunk on its GPU
        outputs = []
        for i, (chunk, model) in enumerate(zip(chunks, self.models)):
            chunk = chunk.to(torch.device(f"cuda:{i}"))
            output = model(chunk)
            outputs.append(output)

        # Gather outputs
        result = gather_sequence_from_gpus(outputs, dim=self.sequence_dim, device=x.device)

        return result


def enable_sequence_parallelism(
    model: nn.Module,
    num_gpus: int = 1,
    sequence_dim: int = 1,
) -> nn.Module:
    """
    Enable sequence parallelism for a model.

    Args:
        model: Model to enable sequence parallelism for
        num_gpus: Number of GPUs to use
        sequence_dim: Dimension of sequence length

    Returns:
        Model wrapped with sequence parallelism
    """
    if num_gpus <= 1:
        logger.warning("Sequence parallelism requires num_gpus > 1; returning the model unchanged")
        return model

    logger.info(f"Enabling sequence parallelism: {num_gpus} GPUs, sequence_dim={sequence_dim}")

    return SequenceParallelWrapper(model, num_gpus, sequence_dim)
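

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch). The split/gather round trip
    # runs on CPU; the wrapper only uses multiple GPUs when at least two
    # CUDA devices are visible.
    logging.basicConfig(level=logging.INFO)

    x = torch.randn(2, 10, 8)
    parts = split_sequence_across_gpus(x, num_gpus=3)
    assert [p.shape[1] for p in parts] == [3, 3, 4]
    assert torch.equal(gather_sequence_from_gpus(parts), x)

    num_gpus = 2 if torch.cuda.device_count() >= 2 else 1
    model = enable_sequence_parallelism(nn.Linear(8, 8), num_gpus=num_gpus)
    print("output shape:", tuple(model(x).shape))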