
# FSDP Integration Complete

Fully Sharded Data Parallel (FSDP) has been fully integrated into the training pipeline.

## ✅ What's Been Integrated

### 1. Service Functions

- `fine_tune_da3()` - FSDP support added
- `pretrain_da3_on_arkit()` - FSDP support added

### 2. API Endpoints

- `/api/v1/train/start` - FSDP parameters added to `TrainRequest`
- `/api/v1/train/pretrain` - FSDP parameters added to `PretrainRequest`

### 3. CLI Commands

- `ylff train start` - FSDP options added
- `ylff train pretrain` - FSDP options added

## 📋 New Parameters

### API Models

`TrainRequest` & `PretrainRequest`:

```python
use_fsdp: bool = False  # Enable FSDP
fsdp_sharding_strategy: str = "FULL_SHARD"  # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
fsdp_mixed_precision: Optional[str] = None  # "bf16", "fp16", or None (auto-detects)
```
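For reference, the three fields and their documented constraints might be modeled roughly as follows. This is a minimal stdlib-dataclass sketch with an illustrative class name; the project's actual request models (`TrainRequest`/`PretrainRequest`) presumably use a schema library and may validate differently.

```python
from dataclasses import dataclass
from typing import Optional

VALID_STRATEGIES = {"FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"}
VALID_PRECISIONS = {"bf16", "fp16", None}

@dataclass
class FSDPOptions:
    """Illustrative container for the three FSDP request fields."""
    use_fsdp: bool = False
    fsdp_sharding_strategy: str = "FULL_SHARD"
    fsdp_mixed_precision: Optional[str] = None

    def __post_init__(self) -> None:
        # Reject values outside the documented choices early,
        # before they reach the training service.
        if self.fsdp_sharding_strategy not in VALID_STRATEGIES:
            raise ValueError(f"unknown sharding strategy: {self.fsdp_sharding_strategy}")
        if self.fsdp_mixed_precision not in VALID_PRECISIONS:
            raise ValueError(f"unknown mixed precision: {self.fsdp_mixed_precision}")
```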

### CLI Options

```text
--use-fsdp                    # Enable FSDP
--fsdp-sharding-strategy      # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
--fsdp-mixed-precision        # bf16, fp16, or None
```

## 🚀 Usage

### CLI Example

```bash
# Multi-GPU training with FSDP
torchrun --nproc_per_node=4 ylff train start data/training \
    --use-fsdp \
    --fsdp-sharding-strategy FULL_SHARD \
    --fsdp-mixed-precision bf16 \
    --use-bf16 \
    --batch-size 2
```

### API Example

```json
{
  "training_data_dir": "data/training",
  "epochs": 10,
  "use_fsdp": true,
  "fsdp_sharding_strategy": "FULL_SHARD",
  "fsdp_mixed_precision": "bf16",
  "use_bf16": true
}
```
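Submitting this payload to `/api/v1/train/start` can be done with the standard library alone; the following sketch builds the request without sending it (the base URL and port are assumptions, not something the document specifies):

```python
import json
import urllib.request

def build_train_request(base_url: str, payload: dict) -> urllib.request.Request:
    """Build (but do not send) a POST request for /api/v1/train/start."""
    return urllib.request.Request(
        url=f"{base_url}/api/v1/train/start",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

payload = {
    "training_data_dir": "data/training",
    "epochs": 10,
    "use_fsdp": True,
    "fsdp_sharding_strategy": "FULL_SHARD",
    "fsdp_mixed_precision": "bf16",
    "use_bf16": True,
}

req = build_train_request("http://localhost:8000", payload)
# To actually start the job: urllib.request.urlopen(req)
```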

## 🔧 How It Works

1. **Model Wrapping**: Before optimizer creation, the model is wrapped with FSDP if:
   - `use_fsdp=True`
   - Distributed training is initialized (`torch.distributed.is_initialized()`)
2. **Sharding Strategy**:
   - `FULL_SHARD`: Shards parameters, gradients, and optimizer states (most memory efficient)
   - `SHARD_GRAD_OP`: Shards only gradients and optimizer states
   - `NO_SHARD`: No sharding (equivalent to DDP)
3. **Mixed Precision**: Auto-detected from `use_bf16` if not specified:
   - If `use_bf16=True` → uses `bf16`
   - If `use_bf16=False` → uses `None` (FP32)

## 📊 Benefits

  • Memory Efficiency: Train models 2-4x larger than single GPU memory
  • Scalability: Better memory efficiency than DDP
  • Performance: Similar speed to DDP with better memory utilization
  • Flexibility: Works with existing optimizations (BF16, gradient clipping, etc.)

## ⚠️ Requirements

1. **PyTorch 2.0+** with FSDP support
2. **Distributed Training**: Must be initialized before training starts:

   ```bash
   torchrun --nproc_per_node=N ...
   ```

   Or manually:

   ```python
   import torch.distributed as dist
   dist.init_process_group(...)
   ```

## 🔄 Integration Points

### Service Functions

Before optimizer creation:

```python
if use_fsdp:
    if dist.is_initialized():
        model = wrap_model_fsdp(
            model,
            sharding_strategy=fsdp_sharding_strategy,
            mixed_precision=fsdp_mixed_precision or ("bf16" if use_bf16 else None),
            device_id=torch.cuda.current_device() if device == "cuda" else None,
        )
```

### Checkpoint Saving/Loading

FSDP checkpoints are handled automatically via `fsdp_utils.py`:

- Uses `FullStateDictConfig` for saving
- Gathers the full state dict on rank 0
- Shards optimizer state properly

## 📝 Files Modified

1. `ylff/services/fine_tune.py` - Added FSDP wrapping
2. `ylff/services/pretrain.py` - Added FSDP wrapping
3. `ylff/models/api_models.py` - Added FSDP fields to request models
4. `ylff/routers/training.py` - Pass FSDP params to service functions
5. `ylff/cli.py` - Added FSDP CLI options

## 🎯 Next Steps

FSDP is fully integrated and ready to use. For best results:

1. Launch with `torchrun` for multi-GPU training
2. Combine with BF16 for maximum memory savings
3. Use `FULL_SHARD` for the largest models
4. Monitor GPU memory usage to verify sharding is taking effect