File size: 6,950 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
"""
Distributed training utilities for multi-GPU training.

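Provides DDP (Distributed Data Parallel) helpers: process-group setup/teardown, model wrapping, distributed sampling, cross-process metric averaging, and checkpoint save/load. No FSDP (Fully Sharded Data Parallel) helpers are defined in this module.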
"""

import logging
import os
from typing import Optional
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

logger = logging.getLogger(__name__)


def setup_ddp(rank: int, world_size: int, backend: str = "nccl"):
    """
    Initialize the default process group for this process.

    ``MASTER_ADDR`` / ``MASTER_PORT`` default to ``localhost`` / ``12355``
    unless already present in the environment.

    Args:
        rank: Global process rank (0 to world_size-1)
        world_size: Total number of processes
        backend: Communication backend ('nccl' for GPU, 'gloo' for CPU)
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # Bind this process to its GPU *before* creating the process group so the
    # NCCL communicator is built on the right device. ``rank % device_count``
    # equals ``rank`` on a single node, and maps global ranks onto the GPUs that
    # actually exist on multi-node runs. On CPU-only hosts (e.g. the 'gloo'
    # backend) there is no device to bind, and set_device would raise.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
    )

    logger.info(f"DDP initialized: rank={rank}, world_size={world_size}, backend={backend}")


def cleanup_ddp():
    """Destroy the default process group if one is active; no-op otherwise."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
    logger.info("DDP cleaned up")


def get_ddp_info() -> dict:
    """
    Summarize the current process-group state.

    Returns:
        Dict with keys ``is_initialized``, ``rank``, ``world_size`` and
        ``backend``. When no process group exists, these are ``False``, ``0``,
        ``1`` and ``None`` (the single-process view).
    """
    if not dist.is_initialized():
        return {"is_initialized": False, "rank": 0, "world_size": 1, "backend": None}

    return {
        "is_initialized": True,
        "rank": dist.get_rank(),
        "world_size": dist.get_world_size(),
        "backend": dist.get_backend(),
    }


def wrap_model_ddp(
    model: torch.nn.Module,
    device: str = "cuda",
    find_unused_parameters: bool = False,
    gradient_as_bucket_view: bool = True,
) -> torch.nn.Module:
    """
    Wrap model with DDP for distributed training.

    Args:
        model: Model to wrap. On the CUDA path it is moved onto this process's
            GPU first, because DDP requires the module to already live on
            ``device_ids[0]``.
        device: ``"cuda"`` (the GPU index is derived from the rank),
            ``"cuda:N"`` (pin to GPU ``N``), or any non-CUDA device string for
            CPU training.
        find_unused_parameters: Whether to find unused parameters (slower but more flexible)
        gradient_as_bucket_view: Use gradient as bucket view for memory efficiency

    Returns:
        DDP-wrapped model, or ``model`` unchanged if no process group is initialized.
    """
    if not dist.is_initialized():
        logger.warning("DDP not initialized, returning unwrapped model")
        return model

    rank = dist.get_rank()
    if str(device).startswith("cuda"):
        explicit_index = torch.device(device).index
        if explicit_index is not None:
            device_id = explicit_index
        else:
            # Global rank modulo local GPU count: identical to ``rank`` on a
            # single node, and never a nonexistent device on multi-node runs.
            # max(..., 1) avoids a ZeroDivisionError so a GPU-less host fails
            # with CUDA's own error from set_device instead.
            device_id = rank % max(torch.cuda.device_count(), 1)
        torch.cuda.set_device(device_id)
        model = model.to(torch.device("cuda", device_id))
    else:
        device_id = None

    ddp_model = DDP(
        model,
        device_ids=[device_id] if device_id is not None else None,
        output_device=device_id,
        find_unused_parameters=find_unused_parameters,
        gradient_as_bucket_view=gradient_as_bucket_view,
    )

    logger.info(f"Model wrapped with DDP (rank={rank})")
    return ddp_model


def create_distributed_sampler(
    dataset,
    shuffle: bool = True,
    seed: int = 0,
) -> Optional[DistributedSampler]:
    """
    Build a sampler that gives each rank its own slice of ``dataset``.

    Args:
        dataset: Dataset to sample from
        shuffle: Whether the sampler shuffles indices
        seed: Seed used for shuffling

    Returns:
        A DistributedSampler sized to the current process group, or None when
        no process group is initialized (single-process run).
    """
    if dist.is_initialized():
        this_rank = dist.get_rank()
        replicas = dist.get_world_size()
        sampler = DistributedSampler(
            dataset,
            num_replicas=replicas,
            rank=this_rank,
            shuffle=shuffle,
            seed=seed,
        )
        logger.info(f"Created DistributedSampler (rank={this_rank}/{replicas})")
        return sampler

    return None


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    All-reduce tensor and compute mean across all processes.

    Args:
        tensor: Tensor to reduce

    Returns:
        Mean value across all processes
    """
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= dist.get_world_size()
    return tensor


def save_checkpoint_ddp(
    model: torch.nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
    checkpoint_path: str,
    is_main_process: bool = True,
):
    """
    Save checkpoint (only on main process to avoid conflicts).

    Every rank must call this function. When a process group is initialized it
    ends with a barrier, so ranks that skip the call would leave the others
    waiting.

    Args:
        model: Model to save (a DDP wrapper is unwrapped so keys have no
            ``module.`` prefix)
        optimizer: Optimizer whose state is saved, or None
        scheduler: LR scheduler whose state is saved, or None
        epoch: Current epoch
        loss: Current loss
        checkpoint_path: Path to save checkpoint
        is_main_process: Whether this is the main process (rank 0)
    """
    if is_main_process:
        # Unwrap DDP model if needed
        target = model.module if isinstance(model, DDP) else model

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": target.state_dict(),
            # Not every training loop has an optimizer/scheduler to persist.
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
            "loss": loss,
        }

        # Write to a sibling temp file and atomically rename it into place, so
        # a crash or preemption mid-write never leaves a truncated file at
        # checkpoint_path (which may hold the previous good checkpoint).
        tmp_path = f"{checkpoint_path}.tmp"
        try:
            torch.save(checkpoint, tmp_path)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        logger.info(f"Saved checkpoint to {checkpoint_path}")

    # Synchronize all processes
    if dist.is_initialized():
        dist.barrier()


def load_checkpoint_ddp(
    model: torch.nn.Module,
    checkpoint_path: str,
    device: str = "cuda",
) -> dict:
    """
    Restore model weights from a checkpoint file.

    The ``"model_state_dict"`` entry is loaded into the underlying module,
    unwrapping a DDP wrapper if present. The rest of the checkpoint (optimizer,
    scheduler, epoch, ...) is returned for the caller to apply.

    Args:
        model: Model (plain or DDP-wrapped) to load into
        checkpoint_path: Path to checkpoint
        device: ``map_location`` passed to ``torch.load``

    Returns:
        The full checkpoint dict
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(checkpoint["model_state_dict"])

    logger.info(f"Loaded checkpoint from {checkpoint_path}")
    return checkpoint


def run_distributed_training(
    rank: int,
    world_size: int,
    train_fn,
    *args,
    **kwargs,
):
    """
    Per-process entry point: set up the process group, train, then tear down.

    ``cleanup_ddp`` runs in a ``finally`` clause. It therefore executes even if
    setup or ``train_fn`` raises, and the original exception still propagates.

    Args:
        rank: Process rank
        world_size: Total number of processes
        train_fn: Callable invoked as ``train_fn(rank, world_size, *args, **kwargs)``
        *args, **kwargs: Extra arguments forwarded to train_fn
    """
    try:
        setup_ddp(rank, world_size)
        train_fn(rank, world_size, *args, **kwargs)
    finally:
        cleanup_ddp()


def _run_distributed_training_with_kwargs(rank, world_size, train_fn, args, kwargs):
    """Spawn target: unpack positional ``args``/``kwargs`` for run_distributed_training."""
    run_distributed_training(rank, world_size, train_fn, *args, **kwargs)


def launch_distributed_training(
    world_size: int,
    train_fn,
    *args,
    **kwargs,
):
    """
    Launch distributed training using torch.multiprocessing.

    Starts ``world_size`` processes. Process ``i`` calls
    ``train_fn(i, world_size, *args, **kwargs)`` inside an initialized
    process group. Blocks until all processes exit.

    Args:
        world_size: Number of GPUs to use
        train_fn: Training function (should accept rank and world_size as first args)
        *args, **kwargs: Additional arguments for train_fn (must be picklable)
    """
    import torch.multiprocessing as mp

    # mp.spawn only forwards a positional ``args`` tuple, so kwargs were
    # previously dropped. Pass them as a plain dict to a module-level
    # (picklable) trampoline that expands them again in the child.
    mp.spawn(
        _run_distributed_training_with_kwargs,
        args=(world_size, train_fn, tuple(args), dict(kwargs)),
        nprocs=world_size,
        join=True,
    )