"""
4D Parallelism Distributed Training Infrastructure

Implements data, tensor, pipeline, and expert parallelism with sequence
parallelism.
"""

import functools
import logging
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

logger = logging.getLogger(__name__)

try:
    from torch.distributed.fsdp import (
        BackwardPrefetch,
        CPUOffload,
        FullyShardedDataParallel as FSDP,
        MixedPrecision,
        ShardingStrategy,
    )
    FSDP_AVAILABLE = True
except ImportError:
    FSDP_AVAILABLE = False
    logger.warning("FSDP not available, falling back to DDP")


class ParallelismType(Enum):
    """Types of parallelism"""
    DATA = "data"
    TENSOR = "tensor"
    PIPELINE = "pipeline"
    EXPERT = "expert"
    SEQUENCE = "sequence"
    ZERO = "zero"


@dataclass
class DistributedConfig:
    """Configuration for distributed training"""

    # Parallelism degrees
    data_parallel_size: int = 1
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    expert_parallel_size: int = 1

    # Sequence parallelism
    enable_sequence_parallel: bool = True
    sequence_parallel_size: int = 1
    max_sequence_length: int = 8192

    # ZeRO optimizer-state sharding
    zero_stage: int = 3
    offload_optimizer: bool = True
    offload_param: bool = False

    # Memory optimizations
    activation_checkpointing: bool = True
    cpu_offload: bool = True
    gradient_compression: bool = True
    mixed_precision: str = "bf16"

    # Communication / DDP settings
    backend: str = "nccl"
    gradient_as_bucket_view: bool = True
    find_unused_parameters: bool = False
    broadcast_buffers: bool = True

    # Pipeline schedule
    num_microbatches: int = 4
    pipeline_schedule: str = "1f1b"

    # Mixture-of-experts routing
    expert_capacity_factor: float = 1.25
    expert_drop_policy: str = "probs"

    # Optimization
    gradient_accumulation_steps: int = 1
    gradient_clipping: float = 1.0

    # Checkpointing
    checkpoint_dir: str = "./checkpoints"
    checkpoint_interval: int = 1000
    use_sharded_checkpointing: bool = True
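

# Example (illustrative): one way to factor a 16-GPU job. The product of the
# data, tensor, and pipeline parallel sizes must equal the world size; the
# helper name below is hypothetical, not part of the public API.
def example_16_gpu_config() -> DistributedConfig:
    config = DistributedConfig(
        data_parallel_size=2,
        tensor_parallel_size=4,
        pipeline_parallel_size=2,
        mixed_precision="bf16",
    )
    assert (config.data_parallel_size
            * config.tensor_parallel_size
            * config.pipeline_parallel_size) == 16
    return config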


class ProcessGroupManager:
    """Manage process groups for the different parallelism types"""

    def __init__(self, config: DistributedConfig):
        self.config = config
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

        # The three core dimensions must exactly tile the world size
        self.total_parallel_size = (
            config.data_parallel_size *
            config.tensor_parallel_size *
            config.pipeline_parallel_size
        )
        if self.total_parallel_size != self.world_size:
            raise ValueError(
                f"data * tensor * pipeline parallel sizes "
                f"({self.total_parallel_size}) must equal the world size "
                f"({self.world_size})"
            )

        self.process_groups = {}
        self._initialize_process_groups()

    def _initialize_process_groups(self):
        """Initialize all process groups.

        ``dist.new_group`` is collective: every rank must call it for every
        group, including groups it does not belong to. Each helper below
        therefore enumerates all groups of its type, and we keep the one
        containing this rank. (Expert parallel groups are not wired up yet.)
        """
        for ptype, all_groups in (
            (ParallelismType.DATA, self._get_data_parallel_ranks()),
            (ParallelismType.TENSOR, self._get_tensor_parallel_ranks()),
            (ParallelismType.PIPELINE, self._get_pipeline_parallel_ranks()),
        ):
            for ranks in all_groups:
                group = dist.new_group(ranks)
                if self.rank in ranks:
                    self.process_groups[ptype] = group

        # Sequence parallelism operates over the tensor parallel group
        if self.config.enable_sequence_parallel:
            self.process_groups[ParallelismType.SEQUENCE] = (
                self.process_groups[ParallelismType.TENSOR]
            )

    def _get_data_parallel_ranks(self) -> List[List[int]]:
        """Enumerate all data parallel groups.

        Rank layout used here: ``rank = dp_idx * (pp * tp) + pp_idx * tp +
        tp_idx``, so data parallel peers (same model shard, different data)
        share the same rank modulo ``tp * pp``.
        """
        tp_size = self.config.tensor_parallel_size
        pp_size = self.config.pipeline_parallel_size
        stride = tp_size * pp_size
        return [
            list(range(offset, self.world_size, stride))
            for offset in range(stride)
        ]

    def _get_tensor_parallel_ranks(self) -> List[List[int]]:
        """Enumerate all tensor parallel groups.

        The tensor parallel index is the innermost axis of the rank layout,
        so each group is a contiguous block of ``tp_size`` ranks.
        """
        tp_size = self.config.tensor_parallel_size
        return [
            list(range(start, start + tp_size))
            for start in range(0, self.world_size, tp_size)
        ]

    def _get_pipeline_parallel_ranks(self) -> List[List[int]]:
        """Enumerate all pipeline parallel groups.

        Within one model replica, pipeline peers share a tensor parallel
        index and are spaced ``tp_size`` ranks apart.
        """
        tp_size = self.config.tensor_parallel_size
        pp_size = self.config.pipeline_parallel_size
        block = tp_size * pp_size
        groups = []
        for base in range(0, self.world_size, block):
            for tp_idx in range(tp_size):
                groups.append(
                    [base + p * tp_size + tp_idx for p in range(pp_size)]
                )
        return groups

    def get_group(self, parallelism_type: ParallelismType):
        """Get the process group for the given parallelism type"""
        return self.process_groups.get(parallelism_type)
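

# Illustration of the rank layout above, assuming world_size=8 with dp=2,
# pp=2, tp=2 (pure Python, no process groups required):
#   tensor groups:   [0,1] [2,3] [4,5] [6,7]
#   pipeline groups: [0,2] [1,3] [4,6] [5,7]
#   data groups:     [0,4] [1,5] [2,6] [3,7]
def enumerate_rank_groups(world_size: int, tp: int, pp: int) -> Dict[str, List[List[int]]]:
    block = tp * pp
    return {
        "tensor": [list(range(s, s + tp)) for s in range(0, world_size, tp)],
        "pipeline": [
            [base + p * tp + t for p in range(pp)]
            for base in range(0, world_size, block)
            for t in range(tp)
        ],
        "data": [list(range(off, world_size, block)) for off in range(block)],
    }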


class MemoryOptimizer:
    """Memory optimization utilities"""

    def __init__(self, config: DistributedConfig):
        self.config = config

    @contextmanager
    def activation_checkpointing_context(self):
        """Context manager used around the forward pass.

        Note: this only releases cached CUDA blocks; true activation
        checkpointing wraps individual layers with ``torch.utils.checkpoint``
        (see ``checkpointed_forward`` below).
        """
        if self.config.activation_checkpointing:
            torch.cuda.empty_cache()
            yield
            torch.cuda.empty_cache()
        else:
            yield
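
    # A minimal sketch of real activation checkpointing, assuming callers wrap
    # individual submodules themselves: torch.utils.checkpoint recomputes the
    # forward during backward instead of storing intermediate activations.
    @staticmethod
    def checkpointed_forward(module: nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
        from torch.utils.checkpoint import checkpoint
        return checkpoint(module, *inputs, use_reentrant=False)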

    def apply_gradient_compression(self, gradients: torch.Tensor) -> torch.Tensor:
        """Apply top-k gradient sparsification"""
        if not self.config.gradient_compression:
            return gradients

        shape = gradients.shape
        flattened = gradients.flatten()

        # Keep only the 10% largest-magnitude entries (at least one)
        k = max(1, int(flattened.numel() * 0.1))
        _, topk_idx = torch.topk(flattened.abs(), k)

        compressed = torch.zeros_like(flattened)
        compressed[topk_idx] = flattened[topk_idx]

        return compressed.reshape(shape)

    def setup_mixed_precision(self):
        """Set up mixed precision training; returns a loss scaler or None"""
        if self.config.mixed_precision == "fp16":
            return torch.cuda.amp.GradScaler()
        elif self.config.mixed_precision == "bf16":
            # bf16 has the same exponent range as fp32, so no loss scaler
            # is needed
            return None
        elif self.config.mixed_precision == "fp8":
            logger.warning("FP8 not fully supported, falling back to BF16")
            return None
        return None
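

# Illustrative check of the top-k compressor, assuming default config values:
# roughly 90% of entries should be zeroed. Not part of the training path.
def _gradient_compression_demo() -> None:
    optimizer = MemoryOptimizer(DistributedConfig())
    grads = torch.randn(1000)
    compressed = optimizer.apply_gradient_compression(grads)
    sparsity = (compressed == 0).float().mean().item()
    print(f"sparsity after compression: {sparsity:.2f}")  # ~0.90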


class FourDParallelModel(nn.Module):
    """Model wrapper for 4D parallelism"""

    def __init__(
        self,
        model: nn.Module,
        config: DistributedConfig
    ):
        super().__init__()
        self.model = model
        self.config = config

        # Process groups for each parallelism type
        self.pg_manager = ProcessGroupManager(config)

        # Memory optimization utilities
        self.memory_optimizer = MemoryOptimizer(config)

        self._apply_parallelism()

    def _apply_parallelism(self):
        """Apply all parallelism types to the model"""

        if self.config.tensor_parallel_size > 1:
            self._apply_tensor_parallelism()

        if self.config.pipeline_parallel_size > 1:
            self._apply_pipeline_parallelism()

        if self.config.data_parallel_size > 1:
            self._apply_data_parallelism()

        if self.config.enable_sequence_parallel:
            self._apply_sequence_parallelism()

    def _apply_tensor_parallelism(self):
        """Replace linear layers with tensor parallel equivalents"""
        tp_group = self.pg_manager.get_group(ParallelismType.TENSOR)
        tp_rank = dist.get_rank(tp_group)
        tp_size = self.config.tensor_parallel_size

        # Snapshot the module list first: mutating the model while iterating
        # named_modules() is unsafe
        for name, module in list(self.model.named_modules()):
            if isinstance(module, nn.Linear):
                # Skip layers whose output dimension cannot be partitioned
                if module.out_features % tp_size != 0:
                    continue

                tp_linear = TensorParallelLinear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    tensor_parallel_size=tp_size,
                    process_group=tp_group
                )

                # Copy this rank's shard: the layer is column parallel, so
                # rows of the weight matrix are split across tensor ranks
                shard = tp_linear.out_features_per_partition
                start = tp_rank * shard
                with torch.no_grad():
                    tp_linear.weight.copy_(module.weight[start:start + shard])
                    if module.bias is not None:
                        tp_linear.bias.copy_(module.bias[start:start + shard])

                # Swap the new layer into the parent module
                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                parent = self.model
                if parent_name:
                    for part in parent_name.split('.'):
                        parent = getattr(parent, part)
                setattr(parent, child_name, tp_linear)

    def _apply_pipeline_parallelism(self):
        """Split the model into pipeline stages (naive even split).

        Assumes the model's top-level children form a sequential stack;
        each pipeline rank keeps only its own stage.
        """
        layers = list(self.model.children())
        num_layers = len(layers)
        layers_per_stage = num_layers // self.config.pipeline_parallel_size

        # Partition layers into contiguous stages; the last stage absorbs
        # any remainder
        stages = []
        for i in range(self.config.pipeline_parallel_size):
            start = i * layers_per_stage
            end = start + layers_per_stage if i < self.config.pipeline_parallel_size - 1 else num_layers
            stages.append(nn.Sequential(*layers[start:end]))

        # Keep the stage for this rank's pipeline position (consistent with
        # the rank layout in ProcessGroupManager)
        tp_size = self.config.tensor_parallel_size
        pp_rank = (dist.get_rank() // tp_size) % self.config.pipeline_parallel_size
        self.model = stages[pp_rank]
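
    # Sketch of microbatch splitting for the "1f1b" schedule named in the
    # config: 1F1B interleaves one forward with one backward per microbatch
    # once the pipeline is full. Only the split is shown here; the actual
    # send/recv scheduling between stages is not implemented in this file.
    @staticmethod
    def split_into_microbatches(
        batch: Dict[str, torch.Tensor],
        num_microbatches: int
    ) -> List[Dict[str, torch.Tensor]]:
        chunked = {k: v.chunk(num_microbatches, dim=0) for k, v in batch.items()}
        return [
            {k: chunks[i] for k, chunks in chunked.items()}
            for i in range(num_microbatches)
        ]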

    def _apply_data_parallelism(self):
        """Apply data parallelism via FSDP (preferred) or DDP"""
        if self.config.zero_stage > 0 and FSDP_AVAILABLE:
            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

            # size_based_auto_wrap_policy is a callback; configure it with
            # functools.partial rather than calling it directly
            auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=int(1e6)
            )

            self.model = FSDP(
                self.model,
                auto_wrap_policy=auto_wrap_policy,
                cpu_offload=CPUOffload(offload_params=self.config.offload_param),
                sharding_strategy=(
                    ShardingStrategy.FULL_SHARD
                    if self.config.zero_stage == 3
                    else ShardingStrategy.SHARD_GRAD_OP
                ),
                backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                mixed_precision=self._get_mixed_precision_policy(),
                device_id=torch.cuda.current_device(),
                process_group=self.pg_manager.get_group(ParallelismType.DATA)
            )
        else:
            self.model = DDP(
                self.model,
                device_ids=[torch.cuda.current_device()],
                output_device=torch.cuda.current_device(),
                process_group=self.pg_manager.get_group(ParallelismType.DATA),
                gradient_as_bucket_view=self.config.gradient_as_bucket_view,
                find_unused_parameters=self.config.find_unused_parameters,
                broadcast_buffers=self.config.broadcast_buffers
            )

    def _apply_sequence_parallelism(self):
        """Apply sequence parallelism to attention layers (stub).

        A full implementation would shard activations along the sequence
        dimension within the tensor parallel group; here we only identify
        the candidate layers.
        """
        for name, module in self.model.named_modules():
            if hasattr(module, 'self_attn') or 'attention' in name.lower():
                logger.info(f"Would apply sequence parallelism to {name}")

    def _get_mixed_precision_policy(self):
        """Get the mixed precision policy for FSDP"""
        # MixedPrecision is imported at module level; this method is only
        # reached on the FSDP path, where that import succeeded
        if self.config.mixed_precision == "fp16":
            return MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16
            )
        elif self.config.mixed_precision == "bf16":
            return MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.bfloat16,
                buffer_dtype=torch.bfloat16
            )
        return None

    def forward(self, *args, **kwargs):
        """Forward pass with memory optimization"""
        with self.memory_optimizer.activation_checkpointing_context():
            return self.model(*args, **kwargs)


class TensorParallelLinear(nn.Module):
    """Column parallel linear layer: the output dimension is split across
    tensor parallel ranks."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        gather_output: bool = True,
        tensor_parallel_size: int = 1,
        process_group: Optional[dist.ProcessGroup] = None
    ):
        super().__init__()

        if out_features % tensor_parallel_size != 0:
            raise ValueError(
                "out_features must be divisible by tensor_parallel_size"
            )

        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.tensor_parallel_size = tensor_parallel_size
        self.process_group = process_group

        # Each rank holds only its slice of the output dimension
        self.out_features_per_partition = out_features // tensor_parallel_size

        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_partition, in_features)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features_per_partition))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters"""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass producing this rank's output partition"""
        output = F.linear(input, self.weight, self.bias)

        # Gather partitions from all tensor parallel ranks. dist.all_gather
        # does not backpropagate into the gathered tensors, so we put the
        # local partition back to preserve this rank's gradient path; a
        # production implementation would use an autograd-aware gather.
        if self.gather_output and self.tensor_parallel_size > 1:
            output_list = [
                torch.empty_like(output)
                for _ in range(self.tensor_parallel_size)
            ]
            dist.all_gather(output_list, output, group=self.process_group)
            output_list[dist.get_rank(self.process_group)] = output
            output = torch.cat(output_list, dim=-1)

        return output
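

# Single-process sanity check of the column parallel math (illustrative, no
# distributed setup needed): concatenating each shard's F.linear output along
# the last dimension reproduces the full layer's output.
def _column_parallel_demo() -> None:
    torch.manual_seed(0)
    full = nn.Linear(8, 4, bias=False)
    x = torch.randn(2, 8)
    shards = full.weight.chunk(2, dim=0)         # two output partitions
    partial = [F.linear(x, w) for w in shards]   # per-"rank" outputs
    gathered = torch.cat(partial, dim=-1)
    assert torch.allclose(gathered, full(x), atol=1e-6)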


class DistributedTrainer:
    """Trainer for 4D parallel training"""

    def __init__(
        self,
        model: nn.Module,
        config: DistributedConfig,
        optimizer_class=torch.optim.AdamW,
        **optimizer_kwargs
    ):
        self.config = config

        # Wrap the model with all configured parallelism types
        self.model = FourDParallelModel(model, config)

        # Optimizer (ZeRO-sharded on the DDP fallback path)
        self.optimizer = self._create_optimizer(optimizer_class, **optimizer_kwargs)

        # Loss scaler (fp16 only; bf16 trains without one)
        self.scaler = self.model.memory_optimizer.setup_mixed_precision()

        self.step = 0

    def _create_optimizer(self, optimizer_class, **kwargs):
        """Create the optimizer, sharding its state with ZeRO if requested"""
        if self.config.zero_stage > 0 and not FSDP_AVAILABLE:
            # FSDP already shards optimizer state, so ZeroRedundancyOptimizer
            # is only needed on the DDP fallback path
            from torch.distributed.optim import ZeroRedundancyOptimizer
            return ZeroRedundancyOptimizer(
                self.model.parameters(),
                optimizer_class=optimizer_class,
                **kwargs
            )
        return optimizer_class(self.model.parameters(), **kwargs)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single training step"""
        self.model.train()

        # Move the batch to the current device
        batch = {k: v.cuda() for k, v in batch.items()}

        # Forward pass under the configured precision: fp16 uses the loss
        # scaler, bf16 needs only autocast
        if self.scaler:
            with torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = self.model(**batch)
                loss = outputs['loss']
        elif self.config.mixed_precision == "bf16":
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                outputs = self.model(**batch)
                loss = outputs['loss']
        else:
            outputs = self.model(**batch)
            loss = outputs['loss']

        # Scale the loss for gradient accumulation
        loss = loss / self.config.gradient_accumulation_steps

        # Backward pass
        if self.scaler:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Optimizer step at accumulation boundaries
        if (self.step + 1) % self.config.gradient_accumulation_steps == 0:
            if self.config.gradient_clipping > 0:
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)
                # Note: FSDP-wrapped models should use the wrapped module's
                # clip_grad_norm_ so norms are computed over full parameters
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.gradient_clipping
                )

            if self.scaler:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            self.optimizer.zero_grad(set_to_none=True)

        self.step += 1

        # Periodic checkpointing
        if self.step % self.config.checkpoint_interval == 0:
            self.save_checkpoint()

        return {
            'loss': loss.item() * self.config.gradient_accumulation_steps,
            'step': self.step
        }

    def save_checkpoint(self):
        """Save a distributed checkpoint"""
        checkpoint_path = os.path.join(
            self.config.checkpoint_dir,
            f"checkpoint_step_{self.step}"
        )

        if dist.get_rank() == 0:
            os.makedirs(checkpoint_path, exist_ok=True)
        # Make sure the directory exists before any rank writes a shard
        dist.barrier()

        if self.config.use_sharded_checkpointing:
            # Every rank writes its own shard
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'step': self.step,
                'config': self.config
            }, f"{checkpoint_path}/shard_{dist.get_rank()}.pt")
        else:
            # Rank 0 writes a single consolidated checkpoint
            if dist.get_rank() == 0:
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'step': self.step,
                    'config': self.config
                }, f"{checkpoint_path}/checkpoint.pt")

        if dist.get_rank() == 0:
            logger.info(f"Saved checkpoint at step {self.step}")