Vedisasi's picture
Upload folder using huggingface_hub
54c5666 verified
"""
4D Parallelism Distributed Training Infrastructure
Implements Data, Tensor, Pipeline, and Expert Parallelism with Sequence Parallelism
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass, field
from enum import Enum
import os
import math
import numpy as np
from contextlib import contextmanager
import logging
# Try to import FSDP components (may not be available in all PyTorch versions)
try:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, BackwardPrefetch, ShardingStrategy, MixedPrecision
FSDP_AVAILABLE = True
except ImportError:
FSDP_AVAILABLE = False
logger = logging.getLogger(__name__)
logger.warning("FSDP not available, falling back to DDP")
logger = logging.getLogger(__name__)
class ParallelismType(Enum):
"""Types of parallelism"""
DATA = "data"
TENSOR = "tensor"
PIPELINE = "pipeline"
EXPERT = "expert"
SEQUENCE = "sequence"
ZERO = "zero"
@dataclass
class DistributedConfig:
"""Configuration for distributed training"""
# Parallelism dimensions
data_parallel_size: int = 1
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
expert_parallel_size: int = 1
# Sequence parallelism
enable_sequence_parallel: bool = True
sequence_parallel_size: int = 1
max_sequence_length: int = 8192
# ZeRO optimization
zero_stage: int = 3
offload_optimizer: bool = True
offload_param: bool = False
# Memory optimization
activation_checkpointing: bool = True
cpu_offload: bool = True
gradient_compression: bool = True
mixed_precision: str = "bf16"
# Communication
backend: str = "nccl"
gradient_as_bucket_view: bool = True
find_unused_parameters: bool = False
broadcast_buffers: bool = True
# Pipeline settings
num_microbatches: int = 4
pipeline_schedule: str = "1f1b"
# Expert settings
expert_capacity_factor: float = 1.25
expert_drop_policy: str = "probs"
# Optimization
gradient_accumulation_steps: int = 1
gradient_clipping: float = 1.0
# Checkpointing
checkpoint_dir: str = "./checkpoints"
checkpoint_interval: int = 1000
use_sharded_checkpointing: bool = True
class ProcessGroupManager:
"""Manage process groups for different parallelism types"""
def __init__(self, config: DistributedConfig):
self.config = config
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
# Calculate total parallelism
self.total_parallel_size = (
config.data_parallel_size *
config.tensor_parallel_size *
config.pipeline_parallel_size
)
# Initialize process groups
self.process_groups = {}
self._initialize_process_groups()
def _initialize_process_groups(self):
"""Initialize all process groups"""
# Data parallel groups
dp_ranks = self._get_data_parallel_ranks()
self.process_groups[ParallelismType.DATA] = dist.new_group(dp_ranks)
# Tensor parallel groups
tp_ranks = self._get_tensor_parallel_ranks()
self.process_groups[ParallelismType.TENSOR] = dist.new_group(tp_ranks)
# Pipeline parallel groups
pp_ranks = self._get_pipeline_parallel_ranks()
self.process_groups[ParallelismType.PIPELINE] = dist.new_group(pp_ranks)
# Sequence parallel groups (same as tensor parallel)
if self.config.enable_sequence_parallel:
self.process_groups[ParallelismType.SEQUENCE] = self.process_groups[ParallelismType.TENSOR]
def _get_data_parallel_ranks(self) -> List[int]:
"""Get ranks for data parallel group"""
dp_size = self.config.data_parallel_size
tp_size = self.config.tensor_parallel_size
pp_size = self.config.pipeline_parallel_size
# Calculate which ranks belong to same DP group
dp_group_idx = self.rank // (tp_size * pp_size)
ranks = []
for r in range(self.world_size):
if r // (tp_size * pp_size) == dp_group_idx:
ranks.append(r)
return ranks
def _get_tensor_parallel_ranks(self) -> List[int]:
"""Get ranks for tensor parallel group"""
tp_size = self.config.tensor_parallel_size
pp_size = self.config.pipeline_parallel_size
tp_group_idx = (self.rank // pp_size) % tp_size
ranks = []
for r in range(self.world_size):
if (r // pp_size) % tp_size == tp_group_idx:
ranks.append(r)
return ranks
def _get_pipeline_parallel_ranks(self) -> List[int]:
"""Get ranks for pipeline parallel group"""
pp_size = self.config.pipeline_parallel_size
pp_group_idx = self.rank % pp_size
ranks = []
for r in range(self.world_size):
if r % pp_size == pp_group_idx:
ranks.append(r)
return ranks
def get_group(self, parallelism_type: ParallelismType):
"""Get process group for given parallelism type"""
return self.process_groups.get(parallelism_type)
class MemoryOptimizer:
"""Memory optimization utilities"""
def __init__(self, config: DistributedConfig):
self.config = config
@contextmanager
def activation_checkpointing_context(self):
"""Context manager for activation checkpointing"""
if self.config.activation_checkpointing:
# Enable gradient checkpointing
torch.cuda.empty_cache()
yield
torch.cuda.empty_cache()
else:
yield
def apply_gradient_compression(self, gradients: torch.Tensor) -> torch.Tensor:
"""Apply gradient compression"""
if not self.config.gradient_compression:
return gradients
# PowerSGD compression (simplified)
# In practice, would use actual PowerSGD implementation
shape = gradients.shape
flattened = gradients.flatten()
# Keep only top-k gradients
k = int(flattened.numel() * 0.1) # Keep 10%
topk_vals, topk_idx = torch.topk(flattened.abs(), k)
compressed = torch.zeros_like(flattened)
compressed[topk_idx] = flattened[topk_idx]
return compressed.reshape(shape)
def setup_mixed_precision(self):
"""Setup mixed precision training"""
if self.config.mixed_precision == "fp16":
return torch.cuda.amp.GradScaler()
elif self.config.mixed_precision == "bf16":
# BF16 doesn't need loss scaling
return None
elif self.config.mixed_precision == "fp8":
# FP8 support (if available)
logger.warning("FP8 not fully supported, falling back to BF16")
return None
return None
class FourDParallelModel(nn.Module):
"""Model wrapper for 4D parallelism"""
def __init__(
self,
model: nn.Module,
config: DistributedConfig
):
super().__init__()
self.model = model
self.config = config
# Process group manager
self.pg_manager = ProcessGroupManager(config)
# Memory optimizer
self.memory_optimizer = MemoryOptimizer(config)
# Apply parallelism
self._apply_parallelism()
def _apply_parallelism(self):
"""Apply all parallelism types to model"""
# Tensor parallelism
if self.config.tensor_parallel_size > 1:
self._apply_tensor_parallelism()
# Pipeline parallelism
if self.config.pipeline_parallel_size > 1:
self._apply_pipeline_parallelism()
# Data parallelism (with DDP or FSDP)
if self.config.data_parallel_size > 1:
self._apply_data_parallelism()
# Sequence parallelism
if self.config.enable_sequence_parallel:
self._apply_sequence_parallelism()
def _apply_tensor_parallelism(self):
"""Apply tensor parallelism to model layers"""
# Replace linear layers with tensor parallel versions
for name, module in self.model.named_modules():
if isinstance(module, nn.Linear):
# Replace with tensor parallel linear
tp_linear = TensorParallelLinear(
module.in_features,
module.out_features,
bias=module.bias is not None,
tensor_parallel_size=self.config.tensor_parallel_size
)
# Copy weights
with torch.no_grad():
tp_linear.weight.copy_(module.weight)
if module.bias is not None:
tp_linear.bias.copy_(module.bias)
# Replace in model
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
child_name = name.rsplit('.', 1)[1] if '.' in name else name
parent = self.model
if parent_name:
for part in parent_name.split('.'):
parent = getattr(parent, part)
setattr(parent, child_name, tp_linear)
def _apply_pipeline_parallelism(self):
"""Apply pipeline parallelism"""
# Split model into stages
layers = list(self.model.children())
num_layers = len(layers)
layers_per_stage = num_layers // self.config.pipeline_parallel_size
# Create pipeline stages
stages = []
for i in range(self.config.pipeline_parallel_size):
start = i * layers_per_stage
end = start + layers_per_stage if i < self.config.pipeline_parallel_size - 1 else num_layers
stage = nn.Sequential(*layers[start:end])
stages.append(stage)
# Assign stages to ranks
pp_rank = dist.get_rank() % self.config.pipeline_parallel_size
self.model = stages[pp_rank]
def _apply_data_parallelism(self):
"""Apply data parallelism"""
if self.config.zero_stage > 0 and FSDP_AVAILABLE:
# Use FSDP for ZeRO optimization
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
auto_wrap_policy = size_based_auto_wrap_policy(
min_num_params=1e6
)
self.model = FSDP(
self.model,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=self.config.offload_param),
sharding_strategy=ShardingStrategy.FULL_SHARD if self.config.zero_stage == 3 else ShardingStrategy.SHARD_GRAD_OP,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
mixed_precision=self._get_mixed_precision_policy(),
device_id=torch.cuda.current_device(),
process_group=self.pg_manager.get_group(ParallelismType.DATA)
)
else:
# Use standard DDP
self.model = DDP(
self.model,
device_ids=[torch.cuda.current_device()],
output_device=torch.cuda.current_device(),
process_group=self.pg_manager.get_group(ParallelismType.DATA),
gradient_as_bucket_view=self.config.gradient_as_bucket_view,
find_unused_parameters=self.config.find_unused_parameters,
broadcast_buffers=self.config.broadcast_buffers
)
def _apply_sequence_parallelism(self):
"""Apply sequence parallelism to attention layers"""
# Replace attention modules with sequence parallel versions
for name, module in self.model.named_modules():
if hasattr(module, 'self_attn') or 'attention' in name.lower():
# This is simplified - actual implementation would be more sophisticated
logger.info(f"Would apply sequence parallelism to {name}")
def _get_mixed_precision_policy(self):
"""Get mixed precision policy for FSDP"""
from torch.distributed.fsdp import MixedPrecision
if self.config.mixed_precision == "fp16":
return MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16
)
elif self.config.mixed_precision == "bf16":
return MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16
)
return None
def forward(self, *args, **kwargs):
"""Forward pass with memory optimization"""
with self.memory_optimizer.activation_checkpointing_context():
return self.model(*args, **kwargs)
class TensorParallelLinear(nn.Module):
"""Tensor parallel linear layer"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
gather_output: bool = True,
tensor_parallel_size: int = 1
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.tensor_parallel_size = tensor_parallel_size
# Calculate per-device dimensions
self.out_features_per_partition = out_features // tensor_parallel_size
# Create weight and bias
self.weight = nn.Parameter(
torch.empty(self.out_features_per_partition, in_features)
)
if bias:
self.bias = nn.Parameter(torch.empty(self.out_features_per_partition))
else:
self.register_parameter('bias', None)
# Initialize
self.reset_parameters()
def reset_parameters(self):
"""Initialize parameters"""
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass with tensor parallelism"""
# Local computation
output = F.linear(input, self.weight, self.bias)
# Gather output if needed
if self.gather_output and self.tensor_parallel_size > 1:
output_list = [torch.empty_like(output) for _ in range(self.tensor_parallel_size)]
dist.all_gather(output_list, output)
output = torch.cat(output_list, dim=-1)
return output
class DistributedTrainer:
"""Trainer for 4D parallel training"""
def __init__(
self,
model: nn.Module,
config: DistributedConfig,
optimizer_class=torch.optim.AdamW,
**optimizer_kwargs
):
self.config = config
# Wrap model with 4D parallelism
self.model = FourDParallelModel(model, config)
# Create optimizer
self.optimizer = self._create_optimizer(optimizer_class, **optimizer_kwargs)
# Setup mixed precision
self.scaler = self.model.memory_optimizer.setup_mixed_precision()
# Stats
self.step = 0
def _create_optimizer(self, optimizer_class, **kwargs):
"""Create optimizer with ZeRO if needed"""
if self.config.zero_stage > 0:
from torch.distributed.optim import ZeroRedundancyOptimizer
return ZeroRedundancyOptimizer(
self.model.parameters(),
optimizer_class=optimizer_class,
**kwargs
)
else:
return optimizer_class(self.model.parameters(), **kwargs)
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Single training step"""
self.model.train()
# Move batch to device
batch = {k: v.cuda() for k, v in batch.items()}
# Mixed precision context
if self.scaler:
with torch.cuda.amp.autocast(dtype=torch.float16):
outputs = self.model(**batch)
loss = outputs['loss']
else:
outputs = self.model(**batch)
loss = outputs['loss']
# Scale loss for gradient accumulation
loss = loss / self.config.gradient_accumulation_steps
# Backward pass
if self.scaler:
self.scaler.scale(loss).backward()
else:
loss.backward()
# Gradient accumulation
if (self.step + 1) % self.config.gradient_accumulation_steps == 0:
# Gradient clipping
if self.config.gradient_clipping > 0:
if self.scaler:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.gradient_clipping
)
# Optimizer step
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self.step += 1
# Checkpoint if needed
if self.step % self.config.checkpoint_interval == 0:
self.save_checkpoint()
return {
'loss': loss.item() * self.config.gradient_accumulation_steps,
'step': self.step
}
def save_checkpoint(self):
"""Save distributed checkpoint"""
checkpoint_path = os.path.join(
self.config.checkpoint_dir,
f"checkpoint_step_{self.step}"
)
if dist.get_rank() == 0:
os.makedirs(checkpoint_path, exist_ok=True)
if self.config.use_sharded_checkpointing:
# Sharded checkpoint
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'step': self.step,
'config': self.config
}, f"{checkpoint_path}/shard_{dist.get_rank()}.pt")
else:
# Single checkpoint (only rank 0)
if dist.get_rank() == 0:
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'step': self.step,
'config': self.config
}, f"{checkpoint_path}/checkpoint.pt")
logger.info(f"Saved checkpoint at step {self.step}")