"""
Fully Sharded Data Parallel (FSDP) utilities for training large models.

FSDP shards model parameters, gradients, and optimizer states across GPUs,
allowing training of models that don't fit on a single GPU.

Requires: PyTorch 2.0+ with FSDP support
"""

import logging
from pathlib import Path
from typing import Optional

try:
    import torch  # type: ignore[import-not-found]
    import torch.nn as nn  # type: ignore[import-not-found]
except Exception:  # pragma: no cover
    torch = None  # type: ignore
    nn = None  # type: ignore

logger = logging.getLogger(__name__)

# Try to import FSDP
try:
    from torch.distributed.fsdp import BackwardPrefetch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

    # Note: transformer_auto_wrap_policy typically needs a partial() with transformer layer classes.
    # We intentionally do not auto-detect layer classes in this repo.
    FSDP_AVAILABLE = True
except Exception:  # pragma: no cover
    # Bind the name so isinstance checks below never hit a NameError.
    FSDP = None  # type: ignore
    FSDP_AVAILABLE = False
    logger.warning("FSDP not available. Requires PyTorch 2.0+ with distributed support.")


def _is_fsdp(model) -> bool:
    """Return True when *model* is an FSDP instance; False if FSDP is unavailable."""
    return FSDP is not None and isinstance(model, FSDP)


def _require_torch() -> None:
    """Raise a clear error when a checkpoint helper is used without PyTorch."""
    if torch is None:
        raise RuntimeError("PyTorch is required for checkpoint save/load but is not installed.")


def _initialized_dist():
    """Return the torch.distributed module if a process group is initialized, else None."""
    try:
        import torch.distributed as dist  # type: ignore
    except Exception:  # pragma: no cover
        return None
    if getattr(dist, "is_initialized", lambda: False)():
        return dist
    return None


def _is_primary_rank(rank: int) -> bool:
    """Return True only for the process that should touch single-file checkpoints.

    The caller-supplied ``rank`` must be 0 and, when a process group is
    initialized, the global rank must be 0 as well. The second check guards
    against callers that rely on the ``rank=0`` default on every process.
    """
    if int(rank) != 0:
        return False
    dist = _initialized_dist()
    if dist is not None:
        return dist.get_rank() == 0
    return True


def _barrier_best_effort(dist) -> None:
    """Run ``dist.barrier()`` if *dist* is set; log instead of raising on failure."""
    if dist is None:
        return
    try:
        dist.barrier()
    except Exception:
        logger.warning("dist.barrier() failed during checkpoint save; continuing", exc_info=True)


def _write_checkpoint_file(payload: dict, path) -> None:
    """Write *payload* with ``torch.save``, creating the parent directory for path-like targets."""
    if isinstance(path, (str, Path)):
        Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)


def wrap_model_fsdp(
    model: "nn.Module",
    sharding_strategy: str = "FULL_SHARD",
    mixed_precision: Optional[str] = "bf16",
    auto_wrap_policy: Optional[str] = None,
    device_id: Optional[int] = None,
    *,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    forward_prefetch: bool = True,
    backward_prefetch: Optional[str] = "BACKWARD_PRE",
    sync_module_states: bool = True,
) -> "nn.Module":
    """
    Wrap model with FSDP for memory-efficient distributed training.

    The model is returned unwrapped (with a warning) when PyTorch/FSDP is not
    importable or when no distributed process group has been initialized.

    Args:
        model: Model to wrap
        sharding_strategy: Sharding strategy:
            - "FULL_SHARD": Shard parameters, gradients, optimizer states (most memory efficient)
            - "SHARD_GRAD_OP": Shard gradients and optimizer states only
            - "NO_SHARD": Don't shard (equivalent to DDP)
            Unknown values fall back to "FULL_SHARD" with a warning.
        mixed_precision: Mixed precision mode: "bf16", "fp16", or None.
            Unknown values disable mixed precision with a warning.
        auto_wrap_policy: Auto-wrap policy: "transformer" or None. No policy is
            configured in this module, so any value results in no auto-wrapping.
        device_id: Device ID for this process
        use_orig_params: Forwarded to the FSDP constructor.
        limit_all_gathers: Forwarded to the FSDP constructor.
        forward_prefetch: Forwarded to the FSDP constructor.
        backward_prefetch: "BACKWARD_PRE", "BACKWARD_POST", or None.
            Unknown values disable backward prefetch with a warning.
        sync_module_states: Forwarded to the FSDP constructor.

    Returns:
        FSDP-wrapped model, or the original model when FSDP cannot be used
    """
    if torch is None or nn is None or not FSDP_AVAILABLE:
        logger.warning("FSDP not available, returning unwrapped model")
        return model

    if _initialized_dist() is None:
        logger.warning("Distributed not initialized, cannot use FSDP")
        return model

    # Convert sharding strategy
    strategy_map = {
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "NO_SHARD": ShardingStrategy.NO_SHARD,
    }
    sharding = strategy_map.get(sharding_strategy)
    if sharding is None:
        logger.warning(
            "Unknown sharding_strategy %r; falling back to FULL_SHARD", sharding_strategy
        )
        sharding = ShardingStrategy.FULL_SHARD

    # Setup mixed precision
    mp_policy = None
    if mixed_precision == "bf16":
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    elif mixed_precision == "fp16":
        mp_policy = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float32,  # Keep buffers in FP32 for stability
        )
    elif mixed_precision is not None:
        logger.warning(
            "Unknown mixed_precision %r; running without mixed precision", mixed_precision
        )

    # Auto-wrap policy for transformer layers
    wrap_policy = None
    if auto_wrap_policy == "transformer":
        logger.warning(
            "auto_wrap_policy='transformer' requested but not configured in this repo. "
            "Pass an explicit wrap policy or keep auto_wrap_policy=None."
        )
    elif auto_wrap_policy is not None:
        logger.warning("Unknown auto_wrap_policy %r; no auto-wrapping applied", auto_wrap_policy)

    bp = None
    if backward_prefetch is not None:
        bp_map = {
            "BACKWARD_PRE": getattr(BackwardPrefetch, "BACKWARD_PRE", None),
            "BACKWARD_POST": getattr(BackwardPrefetch, "BACKWARD_POST", None),
        }
        bp = bp_map.get(str(backward_prefetch))
        if bp is None:
            logger.warning(
                "Unknown backward_prefetch %r; backward prefetch disabled", backward_prefetch
            )

    # Wrap model
    fsdp_model = FSDP(
        model,
        sharding_strategy=sharding,
        mixed_precision=mp_policy,
        auto_wrap_policy=wrap_policy,
        device_id=device_id,
        use_orig_params=bool(use_orig_params),
        limit_all_gathers=bool(limit_all_gathers),
        forward_prefetch=bool(forward_prefetch),
        backward_prefetch=bp,
        sync_module_states=bool(sync_module_states),
    )

    # Log the strategy actually applied, not a possibly-invalid request.
    logger.info(
        "Model wrapped with FSDP: strategy=%s, mixed_precision=%s",
        sharding.name,
        mixed_precision,
    )

    return fsdp_model


def get_fsdp_memory_info(model: "nn.Module") -> dict:
    """
    Get memory usage information for FSDP model.

    Args:
        model: FSDP-wrapped model

    Returns:
        Dict with "is_fsdp", "sharding_strategy" and "mixed_precision". When
        CUDA is available it also carries the calling process's CUDA allocator
        counters for the current device (these are process-wide, not
        per-model). On failure, or when the model is not FSDP-wrapped, a dict
        with a single "error" key is returned.
    """
    if not _is_fsdp(model):
        return {"error": "Model is not wrapped with FSDP"}

    try:
        # This is a simplified version - actual memory tracking is more complex
        info = {
            "is_fsdp": True,
            "sharding_strategy": str(model.sharding_strategy),
            "mixed_precision": str(model.mixed_precision),
        }
        if torch.cuda.is_available():
            info["cuda_memory_allocated_bytes"] = int(torch.cuda.memory_allocated())
            info["cuda_memory_reserved_bytes"] = int(torch.cuda.memory_reserved())
            info["cuda_max_memory_allocated_bytes"] = int(torch.cuda.max_memory_allocated())
        return info
    except Exception as e:
        logger.warning("Could not get FSDP memory info: %s", e)
        return {"error": str(e)}


def save_fsdp_checkpoint(
    model: "nn.Module",
    optimizer,
    epoch: int,
    checkpoint_path: str,
    rank: int = 0,
):
    """
    Save FSDP checkpoint (only on rank 0 to avoid conflicts).

    For an FSDP model every rank must call this function: gathering the full
    model and optimizer state is a collective operation. Only the primary
    rank writes the file.

    Args:
        model: FSDP-wrapped model (a plain module is saved with a standard checkpoint)
        optimizer: Optimizer
        epoch: Current epoch
        checkpoint_path: Path to save checkpoint; the parent directory is created if missing
        rank: Process rank (only rank 0 saves). When a process group is
            initialized the global rank must also be 0, so relying on the
            default on every process cannot clobber the file.

    Raises:
        RuntimeError: If PyTorch is not installed.
    """
    _require_torch()

    if not _is_fsdp(model):
        logger.warning("Model is not FSDP-wrapped, using standard checkpoint save")
        if _is_primary_rank(rank):
            _write_checkpoint_file(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                checkpoint_path,
            )
        return

    # For FSDP, we need to gather full state dict
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        model_state = model.state_dict()
    optimizer_state = FSDP.full_optim_state_dict(model, optimizer)

    if _is_primary_rank(rank):
        _write_checkpoint_file(
            {
                "epoch": epoch,
                "model_state_dict": model_state,
                "optimizer_state_dict": optimizer_state,
            },
            checkpoint_path,
        )
        logger.info("Saved FSDP checkpoint to %s", checkpoint_path)


def save_fsdp_checkpoint_sharded_dir(
    model: "nn.Module",
    optimizer,
    epoch: int,
    checkpoint_dir: str,
    *,
    rank: int = 0,
):
    """
    Save a sharded checkpoint directory using torch.distributed.checkpoint when available.

    This is the recommended path for large-scale FSDP training.

    Layout written under ``checkpoint_dir``:
        - non-FSDP model: ``checkpoint.pt`` (primary rank only)
        - FSDP model without torch.distributed.checkpoint: ``checkpoint_full.pt``
        - FSDP model with torch.distributed.checkpoint: the sharded files plus
          ``meta.pt`` (epoch) and a ``SUCCESS`` marker written by the primary rank

    Args:
        model: Model to checkpoint
        optimizer: Optimizer
        epoch: Current epoch
        checkpoint_dir: Target directory; created if missing
        rank: Process rank (only rank 0 writes single-file artifacts)

    Raises:
        RuntimeError: If PyTorch is not installed.
    """
    _require_torch()
    out_dir = Path(checkpoint_dir)

    if not _is_fsdp(model):
        # Fallback: single file checkpoint, written once to avoid concurrent writers.
        if _is_primary_rank(rank):
            _write_checkpoint_file(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                str(out_dir / "checkpoint.pt"),
            )
        return

    try:
        import torch.distributed.checkpoint as dcp  # type: ignore
        from torch.distributed.checkpoint import FileSystemWriter  # type: ignore
        from torch.distributed.checkpoint.state_dict import get_state_dict  # type: ignore
    except Exception:
        # Conservative fallback: gather full state dict on rank0_only.
        # This is slower but keeps functionality if DCP is unavailable.
        save_fsdp_checkpoint(
            model, optimizer, epoch, str(out_dir / "checkpoint_full.pt"), rank=int(rank)
        )
        return

    out_dir.mkdir(parents=True, exist_ok=True)

    # get_state_dict returns a (model_state, optimizer_state) pair; the
    # checkpoint API expects a single mapping, so store both under fixed keys.
    model_state, optim_state = get_state_dict(model, optimizer)
    state = {"model": model_state, "optim": optim_state}

    # Prefer the current dcp.save entry point; fall back to the older name.
    save_fn = getattr(dcp, "save", None) or dcp.save_state_dict
    save_fn(state, storage_writer=FileSystemWriter(str(out_dir)))

    # Persist small metadata once (avoid multiple writers). The barriers keep
    # the SUCCESS marker from appearing before every rank has finished saving.
    dist = _initialized_dist()
    _barrier_best_effort(dist)
    if _is_primary_rank(rank):
        torch.save({"epoch": int(epoch)}, str(out_dir / "meta.pt"))
        (out_dir / "SUCCESS").write_text("ok\n")
    _barrier_best_effort(dist)


def load_fsdp_checkpoint_sharded_dir(
    model: "nn.Module",
    optimizer,
    checkpoint_dir: str,
    *,
    rank: int = 0,
) -> int:
    """
    Load a sharded checkpoint directory saved by save_fsdp_checkpoint_sharded_dir().

    Args:
        model: Model to restore (state is loaded in place)
        optimizer: Optimizer to restore (state is loaded in place)
        checkpoint_dir: Directory written by save_fsdp_checkpoint_sharded_dir()
        rank: Process rank, used only by the single-file fallback path

    Returns:
        The stored epoch, or 0 when no epoch was recorded

    Raises:
        RuntimeError: If PyTorch is not installed.
    """
    _require_torch()
    in_dir = Path(checkpoint_dir)

    if not _is_fsdp(model):
        checkpoint = torch.load(str(in_dir / "checkpoint.pt"), map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return int(checkpoint.get("epoch", 0))

    try:
        import torch.distributed.checkpoint as dcp  # type: ignore
        from torch.distributed.checkpoint import FileSystemReader  # type: ignore
        from torch.distributed.checkpoint.state_dict import (  # type: ignore
            get_state_dict,
            set_state_dict,
        )
    except Exception:
        # Fallback: full checkpoint path.
        return int(
            load_fsdp_checkpoint(
                model, optimizer, str(in_dir / "checkpoint_full.pt"), rank=int(rank)
            )
        )

    # Build the same {"model", "optim"} mapping used at save time; the load
    # fills these state dicts in place.
    model_state, optim_state = get_state_dict(model, optimizer)
    state = {"model": model_state, "optim": optim_state}

    load_fn = getattr(dcp, "load", None) or dcp.load_state_dict
    load_fn(state, storage_reader=FileSystemReader(str(in_dir)))

    # model_state_dict / optim_state_dict are keyword-only in set_state_dict.
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state["model"],
        optim_state_dict=state["optim"],
    )

    meta_path = in_dir / "meta.pt"
    if meta_path.exists():
        meta = torch.load(str(meta_path), map_location="cpu")
        return int(meta.get("epoch", 0))
    return 0


def load_fsdp_checkpoint(
    model: "nn.Module",
    optimizer,
    checkpoint_path: str,
    rank: int = 0,
) -> int:
    """
    Load FSDP checkpoint.

    For an FSDP model the primary rank reads the file and, when a process
    group is initialized, broadcasts the loaded object to all ranks.

    Args:
        model: FSDP-wrapped model (a plain module is loaded with a standard checkpoint)
        optimizer: Optimizer
        checkpoint_path: Path to checkpoint
        rank: Process rank

    Returns:
        The stored epoch, or 0 when the checkpoint has no "epoch" entry

    Raises:
        RuntimeError: If PyTorch is not installed, or if no checkpoint object
            is available on this rank after the load/broadcast step.
    """
    _require_torch()

    if not _is_fsdp(model):
        logger.warning("Model is not FSDP-wrapped, using standard checkpoint load")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint.get("epoch", 0)

    # Load checkpoint on rank0 and broadcast to all ranks.
    dist = _initialized_dist()

    checkpoint = None
    if _is_primary_rank(rank):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

    if dist is not None:
        holder = [checkpoint]
        dist.broadcast_object_list(holder, src=0)
        checkpoint = holder[0]

    if checkpoint is None:
        raise RuntimeError(f"Failed to load checkpoint: {checkpoint_path}")

    # Load model state dict
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    load_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, load_policy):
        model.load_state_dict(checkpoint["model_state_dict"])

    # Load optimizer state dict
    sharded_optim_state = FSDP.shard_full_optim_state_dict(
        checkpoint["optimizer_state_dict"], model
    )
    optimizer.load_state_dict(sharded_optim_state)

    logger.info("Loaded FSDP checkpoint from %s", checkpoint_path)

    return checkpoint.get("epoch", 0)