| | """ |
| | Fully Sharded Data Parallel (FSDP) utilities for training large models. |
| | |
| | FSDP shards model parameters, gradients, and optimizer states across GPUs, |
| | allowing training of models that don't fit on a single GPU. |
| | |
| | Requires: PyTorch 2.0+ with FSDP support |
| | """ |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | try: |
| | import torch |
| | import torch.nn as nn |
| | except Exception: |
| | torch = None |
| | nn = None |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | try: |
| | from torch.distributed.fsdp import BackwardPrefetch |
| | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy |
| |
|
| | |
| | |
| |
|
| | FSDP_AVAILABLE = True |
| | except Exception: |
| | FSDP_AVAILABLE = False |
| | logger.warning("FSDP not available. Requires PyTorch 2.0+ with distributed support.") |
| |
|
| |
|
def wrap_model_fsdp(
    model: nn.Module,
    sharding_strategy: str = "FULL_SHARD",
    mixed_precision: Optional[str] = "bf16",
    auto_wrap_policy: Optional[str] = None,
    device_id: Optional[int] = None,
    *,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    forward_prefetch: bool = True,
    backward_prefetch: Optional[str] = "BACKWARD_PRE",
    sync_module_states: bool = True,
) -> nn.Module:
    """
    Wrap model with FSDP for memory-efficient distributed training.

    Falls back to returning the unwrapped model when FSDP is unavailable
    or torch.distributed has not been initialized.

    Args:
        model: Model to wrap.
        sharding_strategy: Sharding strategy:
            - "FULL_SHARD": Shard parameters, gradients, optimizer states (most memory efficient)
            - "SHARD_GRAD_OP": Shard gradients and optimizer states only
            - "NO_SHARD": Don't shard (equivalent to DDP)
            Unrecognized values fall back to FULL_SHARD (a warning is logged).
        mixed_precision: Mixed precision mode: "bf16", "fp16", or None.
            Unrecognized values disable mixed precision (a warning is logged).
        auto_wrap_policy: Auto-wrap policy: "transformer" or None. The
            "transformer" policy is not configured in this repo; requesting it
            only logs a warning and no auto-wrap policy is applied.
        device_id: Device ID for this process (forwarded to FSDP).
        use_orig_params: Expose the original (unflattened) parameters through
            the FSDP wrapper.
        limit_all_gathers: Rate-limit all-gathers to cap peak memory.
        forward_prefetch: Prefetch the next all-gather during forward.
        backward_prefetch: "BACKWARD_PRE", "BACKWARD_POST", or None to
            disable backward prefetching. Unrecognized values disable it
            (a warning is logged).
        sync_module_states: Broadcast module states from rank 0 when wrapping.

    Returns:
        FSDP-wrapped model, or the original model when FSDP cannot be used.
    """
    if torch is None or nn is None or not FSDP_AVAILABLE:
        logger.warning("FSDP not available, returning unwrapped model")
        return model

    import torch.distributed as dist

    if not dist.is_initialized():
        logger.warning("Distributed not initialized, cannot use FSDP")
        return model

    strategy_map = {
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "NO_SHARD": ShardingStrategy.NO_SHARD,
    }
    if sharding_strategy not in strategy_map:
        # Previously an unknown strategy silently became FULL_SHARD; make the
        # fallback visible so misconfigurations are caught.
        logger.warning(
            "Unknown sharding_strategy %r, falling back to FULL_SHARD",
            sharding_strategy,
        )
    sharding = strategy_map.get(sharding_strategy, ShardingStrategy.FULL_SHARD)

    mp_policy = None
    if mixed_precision == "bf16":
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    elif mixed_precision == "fp16":
        # Buffers stay fp32 under fp16 (matches the original policy choice).
        mp_policy = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float32,
        )
    elif mixed_precision is not None:
        # Previously any other string silently disabled mixed precision.
        logger.warning(
            "Unknown mixed_precision %r, disabling mixed precision",
            mixed_precision,
        )

    wrap_policy = None
    if auto_wrap_policy == "transformer":
        logger.warning(
            "auto_wrap_policy='transformer' requested but not configured in this repo. "
            "Pass an explicit wrap policy or keep auto_wrap_policy=None."
        )

    bp = None
    if backward_prefetch is not None:
        bp_map = {
            "BACKWARD_PRE": getattr(BackwardPrefetch, "BACKWARD_PRE", None),
            "BACKWARD_POST": getattr(BackwardPrefetch, "BACKWARD_POST", None),
        }
        bp = bp_map.get(str(backward_prefetch))
        if bp is None:
            # Previously an unknown mode silently disabled prefetching.
            logger.warning(
                "Unknown backward_prefetch %r, disabling backward prefetch",
                backward_prefetch,
            )

    fsdp_model = FSDP(
        model,
        sharding_strategy=sharding,
        mixed_precision=mp_policy,
        auto_wrap_policy=wrap_policy,
        device_id=device_id,
        use_orig_params=bool(use_orig_params),
        limit_all_gathers=bool(limit_all_gathers),
        forward_prefetch=bool(forward_prefetch),
        backward_prefetch=bp,
        sync_module_states=bool(sync_module_states),
    )

    # Lazy %-style args avoid formatting when INFO is disabled.
    logger.info(
        "Model wrapped with FSDP: strategy=%s, mixed_precision=%s",
        sharding_strategy,
        mixed_precision,
    )

    return fsdp_model
| |
|
| |
|
def get_fsdp_memory_info(model: nn.Module) -> dict:
    """
    Get sharding/precision configuration info for an FSDP-wrapped model.

    Args:
        model: FSDP-wrapped model.

    Returns:
        Dict with ``is_fsdp``, ``sharding_strategy``, and ``mixed_precision``
        keys, or a dict with a single ``error`` key when the model is not
        FSDP-wrapped or its attributes cannot be read.
    """
    if not isinstance(model, FSDP):
        return {"error": "Model is not wrapped with FSDP"}

    # Attribute access on the wrapper can fail across torch versions, so keep
    # this best-effort.  (The original body had a dead `pass` statement here.)
    try:
        return {
            "is_fsdp": True,
            "sharding_strategy": str(model.sharding_strategy),
            "mixed_precision": str(model.mixed_precision),
        }
    except Exception as e:
        logger.warning("Could not get FSDP memory info: %s", e)
        return {"error": str(e)}
| |
|
| |
|
def save_fsdp_checkpoint(
    model: nn.Module,
    optimizer,
    epoch: int,
    checkpoint_path: str,
    rank: int = 0,
):
    """
    Save FSDP checkpoint (only on rank 0 to avoid conflicts).

    For FSDP models, the full (unsharded) model and optimizer state is
    gathered to rank 0, offloaded to CPU, and written as a single file.
    Plain (non-FSDP) models are saved with a standard torch.save on rank 0.

    Args:
        model: FSDP-wrapped model (a plain nn.Module is also accepted)
        optimizer: Optimizer
        epoch: Current epoch
        checkpoint_path: Path to save checkpoint
        rank: Process rank (only rank 0 saves)
    """
    if not isinstance(model, FSDP):
        logger.warning("Model is not FSDP-wrapped, using standard checkpoint save")
        if int(rank) == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                checkpoint_path,
            )
        return

    # Imported lazily so the module stays importable on torch builds that
    # lack these FSDP extras.
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    # Gather the full state dict on rank 0 only and offload it to CPU so the
    # unsharded model is never materialized on every GPU at once.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

    # NOTE: all ranks must execute this context and the state-dict calls —
    # they are collectives — even though only rank 0 receives the result.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        model_state = model.state_dict()
        optimizer_state = FSDP.full_optim_state_dict(model, optimizer)

    # Only rank 0 touches the filesystem to avoid concurrent writes.
    if int(rank) == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model_state,
                "optimizer_state_dict": optimizer_state,
            },
            checkpoint_path,
        )

    logger.info(f"Saved FSDP checkpoint to {checkpoint_path}")
| |
|
| |
|
def _write_sharded_ckpt_marker(out_dir: Path, epoch: int) -> None:
    """Write epoch metadata (meta.pt) and a SUCCESS sentinel into a checkpoint dir."""
    torch.save({"epoch": int(epoch)}, str(out_dir / "meta.pt"))
    (out_dir / "SUCCESS").write_text("ok\n")


def save_fsdp_checkpoint_sharded_dir(
    model: nn.Module,
    optimizer,
    epoch: int,
    checkpoint_dir: str,
    *,
    rank: int = 0,
):
    """
    Save a sharded checkpoint directory using torch.distributed.checkpoint when available.

    This is the recommended path for large-scale FSDP training. Fallbacks:
      - Non-FSDP model: single-file torch.save (``checkpoint.pt``, rank 0 only).
      - torch.distributed.checkpoint unavailable: consolidated full-state
        checkpoint (``checkpoint_full.pt``) via save_fsdp_checkpoint().

    Args:
        model: FSDP-wrapped model (a plain nn.Module is also accepted)
        optimizer: Optimizer
        epoch: Current epoch (recorded in meta.pt)
        checkpoint_dir: Directory receiving the shards, meta.pt, and SUCCESS marker
        rank: Process rank
    """
    if not isinstance(model, FSDP):
        out_dir = Path(checkpoint_dir)
        # Fix: create the directory and write on rank 0 only — previously every
        # rank wrote the same file concurrently and a missing directory crashed.
        out_dir.mkdir(parents=True, exist_ok=True)
        if int(rank) == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                str(out_dir / "checkpoint.pt"),
            )
        return

    try:
        import torch.distributed.checkpoint as dcp
        from torch.distributed.checkpoint import FileSystemWriter
        from torch.distributed.checkpoint.state_dict import (
            get_state_dict,
            set_state_dict,
        )
    except Exception:
        # Older torch without the distributed-checkpoint API: fall back to a
        # consolidated full-state checkpoint in the same directory.
        Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
        ckpt_path = str(Path(checkpoint_dir) / "checkpoint_full.pt")
        save_fsdp_checkpoint(model, optimizer, epoch, ckpt_path, rank=int(rank))
        return

    out_dir = Path(checkpoint_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): `state` is forwarded to dcp.save_state_dict unchanged, the
    # same shape the loader reads back — confirm against the torch version in use.
    state = get_state_dict(model, optimizer)
    dcp.save_state_dict(
        state_dict=state,
        storage_writer=FileSystemWriter(str(out_dir)),
    )
    # Re-apply the captured state to the live model/optimizer (mirrors the
    # original save flow).
    set_state_dict(model, optimizer, state)

    # Write the marker exactly once (rank 0), with barriers around the write
    # when a process group is active so all ranks agree the save completed.
    try:
        import torch.distributed as dist

        if dist.is_initialized():
            dist.barrier()
            if int(rank) == 0:
                _write_sharded_ckpt_marker(out_dir, epoch)
            dist.barrier()
        elif int(rank) == 0:
            _write_sharded_ckpt_marker(out_dir, epoch)
    except Exception:
        # Best-effort: still drop the marker if the barrier/import failed.
        if int(rank) == 0:
            _write_sharded_ckpt_marker(out_dir, epoch)
| |
|
| |
|
def load_fsdp_checkpoint_sharded_dir(
    model: nn.Module,
    optimizer,
    checkpoint_dir: str,
    *,
    rank: int = 0,
) -> int:
    """
    Load a sharded checkpoint directory saved by save_fsdp_checkpoint_sharded_dir().

    Returns the epoch recorded in the checkpoint metadata, or 0 when absent.
    """
    ckpt_dir = Path(checkpoint_dir)

    # Plain (non-FSDP) model: read the single-file checkpoint directly.
    if not isinstance(model, FSDP):
        ckpt_path = str(ckpt_dir / "checkpoint.pt")
        payload = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(payload["model_state_dict"])
        optimizer.load_state_dict(payload["optimizer_state_dict"])
        return int(payload.get("epoch", 0))

    try:
        import torch.distributed.checkpoint as dcp
        from torch.distributed.checkpoint import FileSystemReader
        from torch.distributed.checkpoint.state_dict import (
            get_state_dict,
            set_state_dict,
        )
    except Exception:
        # Distributed-checkpoint API unavailable: read the consolidated
        # full-state checkpoint instead.
        fallback_path = str(ckpt_dir / "checkpoint_full.pt")
        return int(load_fsdp_checkpoint(model, optimizer, fallback_path, rank=int(rank)))

    # Populate the captured state in place from the sharded files, then apply
    # it back onto the live model/optimizer.
    state = get_state_dict(model, optimizer)
    dcp.load_state_dict(
        state_dict=state,
        storage_reader=FileSystemReader(str(ckpt_dir)),
    )
    set_state_dict(model, optimizer, state)

    # Epoch metadata is optional; older directories may lack meta.pt.
    meta_file = ckpt_dir / "meta.pt"
    if not meta_file.exists():
        return 0
    meta = torch.load(str(meta_file), map_location="cpu")
    return int(meta.get("epoch", 0))
| |
|
| |
|
def load_fsdp_checkpoint(
    model: nn.Module,
    optimizer,
    checkpoint_path: str,
    rank: int = 0,
):
    """
    Load FSDP checkpoint.

    Rank 0 reads the checkpoint file from disk; when a process group is
    initialized, the deserialized object is broadcast to all other ranks.
    The full optimizer state is then re-sharded for this rank.

    Args:
        model: FSDP-wrapped model (a plain nn.Module is also accepted)
        optimizer: Optimizer
        checkpoint_path: Path to checkpoint
        rank: Process rank

    Returns:
        The epoch stored in the checkpoint (0 when absent).
    """
    if not isinstance(model, FSDP):
        logger.warning("Model is not FSDP-wrapped, using standard checkpoint load")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint.get("epoch", 0)

    # torch.distributed may be missing on some builds; degrade to local load.
    try:
        import torch.distributed as dist
    except Exception:
        dist = None

    # Only rank 0 touches the filesystem; other ranks receive the object via
    # broadcast below when a process group is active.
    checkpoint = None
    if int(rank) == 0:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if dist is not None and getattr(dist, "is_initialized", lambda: False)():
        obj_list = [checkpoint]
        dist.broadcast_object_list(obj_list, src=0)
        checkpoint = obj_list[0]
    # Guards the case where a non-zero rank runs without an initialized
    # process group and therefore never obtained the checkpoint object.
    if checkpoint is None:
        raise RuntimeError(f"Failed to load checkpoint: {checkpoint_path}")

    # Imported lazily so the module stays importable without FSDP extras.
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    load_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

    # NOTE: all ranks must enter this context; load_state_dict under
    # FULL_STATE_DICT distributes the full weights into local shards.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, load_policy):
        model.load_state_dict(checkpoint["model_state_dict"])

    # Re-shard the full optimizer state down to this rank's local shards
    # before handing it to the optimizer.
    sharded_optim_state = FSDP.shard_full_optim_state_dict(
        checkpoint["optimizer_state_dict"], model
    )
    optimizer.load_state_dict(sharded_optim_state)

    logger.info(f"Loaded FSDP checkpoint from {checkpoint_path}")
    return checkpoint.get("epoch", 0)
| |
|