import torch
import torch.nn as nn
import torch.distributed as dist
from typing import List, Optional, Dict, Any, Tuple
from functools import partial
import logging
import os
from contextlib import contextmanager

from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Pipeline parallelism is optional; degrade gracefully if either import fails.
try:
    from torch.distributed.pipeline.sync import Pipe
except Exception:
    Pipe = None
try:
    from torch.distributed._pipeline.sync import balance
except Exception:
    balance = None

from .model import BitTransformerLM, LoggingTransformerEncoderLayer
from .error_handling import with_error_recovery, safe_operation
from .types import DeviceType, WorldSize, ProcessRank

@with_error_recovery(max_retries=2)
def setup_distributed(rank: ProcessRank = 0,
                      world_size: WorldSize = 1,
                      backend: str = "nccl",
                      init_method: str = "tcp://localhost:23456") -> bool:
    """Initialize distributed training environment."""
    if world_size <= 1:
        return False

    try:
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
        logging.info(f"Initialized distributed training: rank {rank}/{world_size}")
        return True
    except Exception as e:
        logging.error(f"Failed to initialize distributed training: {e}")
        return False

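# Example (illustrative, not exercised by this module): one process per GPU,
# with rank and world size taken from the launcher's environment. The variable
# names below assume a torchrun-style launcher.
#
#     rank = int(os.environ.get("RANK", 0))
#     world_size = int(os.environ.get("WORLD_SIZE", 1))
#     if setup_distributed(rank=rank, world_size=world_size, backend="nccl"):
#         torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
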
def wrap_fsdp(model: BitTransformerLM,
              sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
              **kwargs) -> FullyShardedDataParallel:
    """Return an optimized FSDP-wrapped model with transformer-aware sharding."""
    device = kwargs.pop("device_id", None)
    if device is None and torch.cuda.is_available():
        device = torch.cuda.current_device()

    # transformer_auto_wrap_policy must be bound to the transformer layer class
    # it should wrap; the bare function cannot be passed to FSDP directly.
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LoggingTransformerEncoderLayer},
    )

    fsdp_config = {
        "sharding_strategy": sharding_strategy,
        "cpu_offload": kwargs.pop("cpu_offload", None),
        "mixed_precision": kwargs.pop("mixed_precision", None),
        "auto_wrap_policy": auto_wrap_policy,
        "backward_prefetch": kwargs.pop("backward_prefetch", None),
        "forward_prefetch": kwargs.pop("forward_prefetch", False),
        "limit_all_gathers": kwargs.pop("limit_all_gathers", True),
        "use_orig_params": kwargs.pop("use_orig_params", True),
        **kwargs,
    }

    # Drop unset options so FSDP falls back to its own defaults.
    fsdp_config = {k: v for k, v in fsdp_config.items() if v is not None}

    if device is not None:
        model = model.to(device)
        fsdp_config["device_id"] = device

    return FullyShardedDataParallel(model, **fsdp_config)

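# Example (illustrative): enabling bf16 mixed precision and CPU offload when
# wrapping. MixedPrecision and CPUOffload are the standard config dataclasses
# from torch.distributed.fsdp; the model constructor arguments are placeholders.
#
#     from torch.distributed.fsdp import MixedPrecision, CPUOffload
#
#     model = BitTransformerLM(...)  # hypothetical constructor arguments
#     fsdp_model = wrap_fsdp(
#         model,
#         sharding_strategy=ShardingStrategy.FULL_SHARD,
#         mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
#                                        reduce_dtype=torch.bfloat16),
#         cpu_offload=CPUOffload(offload_params=True),
#     )
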
class OptimizedPipeline(nn.Module):
    """Enhanced pipeline parallelism with BitTransformerLM optimizations."""

    def __init__(self,
                 model: BitTransformerLM,
                 num_stages: int = 1,
                 chunks: int = 1,
                 checkpoint: bool = True):
        super().__init__()

        if Pipe is None:
            raise RuntimeError("Pipeline parallelism not available in this build")

        self.num_stages = num_stages
        self.chunks = chunks
        self.checkpoint = checkpoint

        if num_stages > 1:
            self.pipeline_model = self._create_pipeline_stages(model, num_stages)
        else:
            self.pipeline_model = Pipe(nn.Sequential(model), chunks=chunks)

    def _create_pipeline_stages(self, model: BitTransformerLM, num_stages: int) -> Pipe:
        """Create optimized pipeline stages for BitTransformerLM."""
        layers = []

        # Input layers (embedding and positional encoding) come first.
        if hasattr(model, 'embedding'):
            layers.append(model.embedding)
        if hasattr(model, 'pos_encoding'):
            layers.append(model.pos_encoding)

        # Transformer blocks form the bulk of the pipeline.
        if hasattr(model, 'layers'):
            layers.extend(model.layers)
        elif hasattr(model, 'transformer'):
            layers.extend(model.transformer.layers)

        # Output projection closes the pipeline.
        if hasattr(model, 'output_projection'):
            layers.append(model.output_projection)

        # Partition layers across stages, using the balance helper when available.
        if balance is not None:
            partitions = balance(len(layers), num_stages)
        else:
            # Even split, with the remainder assigned to the last stage.
            layers_per_stage = len(layers) // num_stages
            partitions = [layers_per_stage] * num_stages
            partitions[-1] += len(layers) % num_stages

        # Group consecutive layers into sequential stage modules.
        stages = []
        start_idx = 0
        for partition_size in partitions:
            end_idx = start_idx + partition_size
            stage_layers = layers[start_idx:end_idx]
            stages.append(nn.Sequential(*stage_layers))
            start_idx = end_idx

        return Pipe(nn.Sequential(*stages), chunks=self.chunks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the pipeline."""
        return self.pipeline_model(x)

def make_pipeline(model: BitTransformerLM,
                  chunks: int = 1,
                  num_stages: int = 1,
                  checkpoint: bool = True) -> OptimizedPipeline:
    """Create an optimized pipeline with advanced parallelism features."""
    return OptimizedPipeline(
        model=model,
        num_stages=num_stages,
        chunks=chunks,
        checkpoint=checkpoint,
    )

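# Example (illustrative): splitting a model into two pipeline stages with four
# micro-batches per step. torch's Pipe additionally requires the RPC framework
# to be initialized and stage modules to be placed on their devices before
# wrapping; that setup is omitted here.
#
#     model = BitTransformerLM(...)  # hypothetical constructor arguments
#     pipe_model = make_pipeline(model, chunks=4, num_stages=2)
#     out = pipe_model(inputs)
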
class DistributedTrainingManager:
    """Manages distributed training configuration and optimization."""

    def __init__(self,
                 world_size: WorldSize,
                 rank: ProcessRank,
                 use_pipeline: bool = False,
                 use_fsdp: bool = True):
        self.world_size = world_size
        self.rank = rank
        self.use_pipeline = use_pipeline
        self.use_fsdp = use_fsdp
        self.is_distributed = world_size > 1

        self.logger = logging.getLogger(__name__)

    def setup_model(self,
                    model: BitTransformerLM,
                    pipeline_stages: int = 1,
                    fsdp_config: Optional[Dict[str, Any]] = None) -> nn.Module:
        """Set up the model for distributed training."""
        if not self.is_distributed:
            return model

        with safe_operation("distributed_model_setup"):
            if self.use_pipeline and pipeline_stages > 1:
                self.logger.info(f"Setting up pipeline parallelism with {pipeline_stages} stages")
                return make_pipeline(
                    model,
                    chunks=2,
                    num_stages=pipeline_stages,
                )

            elif self.use_fsdp:
                self.logger.info("Setting up FSDP for data parallelism")
                fsdp_config = fsdp_config or {}
                return wrap_fsdp(model, **fsdp_config)

            else:
                self.logger.info("Using standard DistributedDataParallel")
                return nn.parallel.DistributedDataParallel(model)

    def optimize_communication(self, model: nn.Module) -> None:
        """Apply communication optimizations for distributed training."""
        if not self.is_distributed:
            return

        if isinstance(model, nn.parallel.DistributedDataParallel):
            # DDP's bucket size can only be tuned via the ``bucket_cap_mb``
            # constructor argument, so it is not adjusted after construction.
            # Register a gradient-compression hook to reduce communication volume.
            try:
                if hasattr(model, 'register_comm_hook'):
                    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
                    model.register_comm_hook(
                        dist.group.WORLD,
                        default_hooks.fp16_compress_hook,
                    )
            except ImportError:
                pass

    @contextmanager
    def training_context(self):
        """Context manager for distributed training setup."""
        try:
            if self.is_distributed:
                self.logger.info("Entering distributed training context")
                if torch.cuda.is_available():
                    torch.cuda.set_device(self.rank)
            yield
        finally:
            if self.is_distributed:
                self.logger.info("Exiting distributed training context")

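# Example (illustrative): a typical single-node flow using the manager. The
# dataloader and optimizer objects are placeholders.
#
#     manager = DistributedTrainingManager(world_size=world_size, rank=rank,
#                                          use_fsdp=True)
#     dist_model = manager.setup_model(model)
#     manager.optimize_communication(dist_model)
#     with manager.training_context():
#         for batch in dataloader:
#             ...  # forward/backward/step as usual
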
def cleanup_distributed():
    """Clean up distributed training environment."""
    if dist.is_initialized():
        dist.destroy_process_group()
        logging.info("Distributed training cleaned up")

def get_distributed_config() -> Dict[str, Any]:
    """Get current distributed training configuration."""
    if not dist.is_initialized():
        return {"distributed": False}

    return {
        "distributed": True,
        "world_size": dist.get_world_size(),
        "rank": dist.get_rank(),
        "backend": dist.get_backend(),
        "local_rank": int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else None,
    }

def all_reduce_tensor(tensor: torch.Tensor,
                      op: dist.ReduceOp = dist.ReduceOp.SUM) -> torch.Tensor:
    """All-reduce operation on tensor across all processes."""
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=op)
    return tensor

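# Example (illustrative): averaging a scalar metric across ranks by summing and
# dividing by the world size.
#
#     loss_tensor = torch.tensor([loss.item()], device="cuda")
#     all_reduce_tensor(loss_tensor, op=dist.ReduceOp.SUM)
#     mean_loss = loss_tensor.item() / dist.get_world_size()
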
def gather_tensors(tensor: torch.Tensor,
                   dst: int = 0) -> Optional[List[torch.Tensor]]:
    """Gather tensors from all processes to destination rank."""
    if not dist.is_initialized():
        return [tensor]

    if dist.get_rank() == dst:
        tensor_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.gather(tensor, tensor_list, dst=dst)
        return tensor_list
    else:
        dist.gather(tensor, dst=dst)
        return None

def broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast tensor from source rank to all processes."""
    if not dist.is_initialized():
        return tensor

    dist.broadcast(tensor, src=src)
    return tensor

class PipelineScheduler:
    """Advanced scheduler for pipeline parallelism with load balancing."""

    def __init__(self, num_stages: int, world_size: int):
        self.num_stages = num_stages
        self.world_size = world_size
        self.stage_times = [0.0] * num_stages
        self.load_balance_enabled = True

    def update_stage_timing(self, stage_id: int, execution_time: float):
        """Update the execution-time estimate for a pipeline stage."""
        if 0 <= stage_id < self.num_stages:
            # Exponential moving average keeps the estimate stable across noisy steps.
            alpha = 0.1
            self.stage_times[stage_id] = (1 - alpha) * self.stage_times[stage_id] + alpha * execution_time

    def get_optimal_chunks(self, batch_size: int) -> int:
        """Calculate the optimal number of micro-batch chunks from stage timing."""
        if not self.load_balance_enabled:
            return max(1, batch_size // 8)

        max_stage_time = max(self.stage_times) if any(self.stage_times) else 1.0
        avg_stage_time = sum(self.stage_times) / len(self.stage_times) if self.stage_times else 1.0

        # More chunks help hide imbalance between the slowest and the average stage.
        imbalance_factor = max_stage_time / max(avg_stage_time, 1e-6)
        optimal_chunks = max(2, min(batch_size, int(4 * imbalance_factor)))

        return optimal_chunks

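# Example (illustrative): feeding measured stage times back into the scheduler
# and using the result to size micro-batches. The timings are placeholders.
#
#     scheduler = PipelineScheduler(num_stages=2, world_size=2)
#     scheduler.update_stage_timing(0, 0.012)  # seconds per micro-batch
#     scheduler.update_stage_timing(1, 0.030)
#     chunks = scheduler.get_optimal_chunks(batch_size=64)
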
def efficient_gradient_sync(model: nn.Module, gradient_clipping: float = 1.0):
    """Perform memory-efficient gradient synchronization across processes."""
    if not dist.is_initialized():
        return

    # Average gradients across ranks parameter by parameter. This avoids holding
    # a flattened copy of every gradient at once, trading bandwidth for memory.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, async_op=False)
            param.grad /= world_size

    # Clip the synchronized gradients so every rank applies the same clipping.
    if gradient_clipping > 0:
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        if dist.get_rank() == 0:
            logging.debug(f"Gradient norm before clipping: {total_norm.item():.4f}")

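# Example (illustrative): using the helper in a manual training step instead of
# DDP's built-in gradient averaging. The model and optimizer are placeholders.
#
#     loss = compute_loss(model, batch)
#     loss.backward()
#     efficient_gradient_sync(model, gradient_clipping=1.0)
#     optimizer.step()
#     optimizer.zero_grad()
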
class DistributedMemoryManager:
    """Manages memory efficiently across distributed processes."""

    def __init__(self, enable_cpu_offload: bool = False):
        self.enable_cpu_offload = enable_cpu_offload
        self.memory_stats = {}
        self.peak_memory = 0

    def monitor_memory(self):
        """Monitor GPU memory usage on this process."""
        if torch.cuda.is_available():
            current_memory = torch.cuda.memory_allocated()
            max_memory = torch.cuda.max_memory_allocated()

            self.memory_stats = {
                "current_gb": current_memory / 1e9,
                "peak_gb": max_memory / 1e9,
                "rank": dist.get_rank() if dist.is_initialized() else 0,
            }

            self.peak_memory = max(self.peak_memory, current_memory)

    def optimize_memory_usage(self):
        """Apply memory optimizations based on current usage."""
        if torch.cuda.is_available():
            # Release cached blocks when allocations approach device capacity.
            total_memory = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
            if torch.cuda.memory_allocated() > 0.8 * total_memory:
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache due to high memory usage")

    def get_memory_report(self) -> Dict[str, float]:
        """Get a comprehensive memory usage report."""
        self.monitor_memory()
        return self.memory_stats

# Module-level defaults; setup_advanced_distributed_training() replaces these
# with instances sized to the actual world configuration.
pipeline_scheduler = PipelineScheduler(num_stages=1, world_size=1)
memory_manager = DistributedMemoryManager()

def setup_advanced_distributed_training(
    rank: ProcessRank,
    world_size: WorldSize,
    enable_memory_monitoring: bool = True,
    enable_pipeline_scheduling: bool = True,
) -> Dict[str, Any]:
    """Set up advanced distributed training with optimizations."""
    global pipeline_scheduler, memory_manager

    success = setup_distributed(rank, world_size)
    if not success:
        return {"distributed": False}

    if enable_pipeline_scheduling:
        pipeline_scheduler = PipelineScheduler(num_stages=world_size, world_size=world_size)

    if enable_memory_monitoring:
        memory_manager = DistributedMemoryManager()
        memory_manager.monitor_memory()

    config = get_distributed_config()
    config.update({
        "pipeline_scheduling": enable_pipeline_scheduling,
        "memory_monitoring": enable_memory_monitoring,
        "advanced_features": True,
    })

    logging.info(f"Advanced distributed training initialized on rank {rank}")
    return config

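# Example (illustrative): an end-to-end sketch of the intended flow, assuming a
# torchrun-style launcher that sets RANK/WORLD_SIZE. Model and data objects are
# placeholders.
#
#     rank = int(os.environ.get("RANK", 0))
#     world_size = int(os.environ.get("WORLD_SIZE", 1))
#     config = setup_advanced_distributed_training(rank, world_size)
#     manager = DistributedTrainingManager(world_size, rank, use_fsdp=True)
#     dist_model = manager.setup_model(BitTransformerLM(...))
#     with manager.training_context():
#         ...  # training loop; call memory_manager.optimize_memory_usage() periodically
#     cleanup_distributed()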