# FSDP Integration Complete

Fully Sharded Data Parallel (FSDP) has been fully integrated into the training pipeline.

## What's Been Integrated

### 1. Service Functions

- **`fine_tune_da3()`** - FSDP support added
- **`pretrain_da3_on_arkit()`** - FSDP support added

### 2. API Endpoints

- **`/api/v1/train/start`** - FSDP parameters added to `TrainRequest`
- **`/api/v1/train/pretrain`** - FSDP parameters added to `PretrainRequest`

### 3. CLI Commands

- **`ylff train start`** - FSDP options added
- **`ylff train pretrain`** - FSDP options added

## New Parameters

### API Models

**TrainRequest & PretrainRequest**:

```python
use_fsdp: bool = False                      # Enable FSDP
fsdp_sharding_strategy: str = "FULL_SHARD"  # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
fsdp_mixed_precision: Optional[str] = None  # bf16, fp16, or None (auto-detects)
```
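For context, here is a minimal standalone sketch of how these fields sit alongside the existing ones in a request model. It is illustrative only, written as a stdlib dataclass rather than the actual model class in `ylff/models/api_models.py`; defaults other than the three FSDP fields are taken from the API example below, not from the real code.

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative sketch of a request model carrying the new FSDP fields.
# The real TrainRequest has more fields; this shows only the shape.
@dataclass
class TrainRequest:
    training_data_dir: str
    epochs: int = 10
    use_bf16: bool = False
    use_fsdp: bool = False                      # Enable FSDP
    fsdp_sharding_strategy: str = "FULL_SHARD"  # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
    fsdp_mixed_precision: Optional[str] = None  # bf16, fp16, or None (auto-detects)
```

A caller that omits the FSDP fields gets the non-sharded defaults, so existing requests keep working unchanged.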
### CLI Options

```bash
--use-fsdp                 # Enable FSDP
--fsdp-sharding-strategy   # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
--fsdp-mixed-precision     # bf16, fp16, or None
```
## Usage

### CLI Example

```bash
# Multi-GPU training with FSDP
torchrun --nproc_per_node=4 ylff train start data/training \
    --use-fsdp \
    --fsdp-sharding-strategy FULL_SHARD \
    --fsdp-mixed-precision bf16 \
    --use-bf16 \
    --batch-size 2
```

### API Example

```json
{
  "training_data_dir": "data/training",
  "epochs": 10,
  "use_fsdp": true,
  "fsdp_sharding_strategy": "FULL_SHARD",
  "fsdp_mixed_precision": "bf16",
  "use_bf16": true
}
```
## How It Works

1. **Model Wrapping**: Before optimizer creation, the model is wrapped with FSDP if:
   - `use_fsdp=True`
   - Distributed training is initialized (`torch.distributed.is_initialized()`)
2. **Sharding Strategy**:
   - **FULL_SHARD**: Shards parameters, gradients, and optimizer states (most memory efficient)
   - **SHARD_GRAD_OP**: Shards only gradients and optimizer states
   - **NO_SHARD**: No sharding (equivalent to DDP)
3. **Mixed Precision**: Auto-detected from `use_bf16` if not specified:
   - `use_bf16=True` → uses `bf16`
   - `use_bf16=False` → uses `None` (full FP32)
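The resolution logic in steps 2 and 3 can be sketched as two small helpers. These names are illustrative, not the actual `ylff` internals:

```python
from typing import Optional

# Valid values for fsdp_sharding_strategy, per the list above.
VALID_STRATEGIES = {"FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"}


def resolve_mixed_precision(fsdp_mixed_precision: Optional[str],
                            use_bf16: bool) -> Optional[str]:
    """Auto-detect mixed precision from use_bf16 when not set explicitly."""
    if fsdp_mixed_precision is not None:
        return fsdp_mixed_precision          # explicit "bf16"/"fp16" wins
    return "bf16" if use_bf16 else None      # None means full FP32


def check_sharding_strategy(name: str) -> str:
    """Reject strategy strings that FSDP would not recognize."""
    if name not in VALID_STRATEGIES:
        raise ValueError(f"unknown FSDP sharding strategy: {name!r}")
    return name
```

For example, `resolve_mixed_precision(None, use_bf16=True)` yields `"bf16"`, while an explicit `"fp16"` is passed through regardless of the `use_bf16` flag.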
## Benefits

- **Memory Efficiency**: Train models 2-4x larger than what fits in a single GPU's memory
- **Scalability**: Per-GPU memory footprint shrinks as the number of GPUs grows, unlike DDP
- **Performance**: Throughput comparable to DDP, with much better memory utilization
- **Flexibility**: Works with existing optimizations (BF16, gradient clipping, etc.)
## Requirements

1. **PyTorch 2.0+** with FSDP support
2. **Distributed Training**: Must initialize distributed training first:
   ```bash
   torchrun --nproc_per_node=N ...
   ```
   Or manually initialize:
   ```python
   import torch.distributed as dist
   dist.init_process_group(...)
   ```
## Integration Points

### Service Functions

**Before optimizer creation**:

```python
if use_fsdp and dist.is_initialized():
    model = wrap_model_fsdp(
        model,
        sharding_strategy=fsdp_sharding_strategy,
        mixed_precision=fsdp_mixed_precision or ("bf16" if use_bf16 else None),
        device_id=torch.cuda.current_device() if device == "cuda" else None,
    )
```
### Checkpoint Saving/Loading

FSDP checkpoints are handled automatically via `fsdp_utils.py`:

- Uses `FullStateDictConfig` for saving
- Gathers the full state dict on rank 0
- Shards optimizer state properly
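The rank-0 gather above follows PyTorch's standard full-state-dict pattern. A hedged sketch, assuming `model` is already FSDP-wrapped and a process group is initialized (this is not the exact `fsdp_utils.py` code, and it only runs inside a live distributed job):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather the full (unsharded) state dict on rank 0 only; offloading to
# CPU avoids a GPU OOM while the shards are being consolidated.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state = model.state_dict()  # full dict on rank 0, empty elsewhere

if dist.get_rank() == 0:
    torch.save(state, "checkpoint.pt")
```

With `rank0_only=True`, non-zero ranks receive an empty dict, so only rank 0 writes the file.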
## Files Modified

1. **`ylff/services/fine_tune.py`** - Added FSDP wrapping
2. **`ylff/services/pretrain.py`** - Added FSDP wrapping
3. **`ylff/models/api_models.py`** - Added FSDP fields to request models
4. **`ylff/routers/training.py`** - Pass FSDP params to service functions
5. **`ylff/cli.py`** - Added FSDP CLI options
## Next Steps

FSDP is fully integrated and ready to use. For best results:

1. Launch with `torchrun` for multi-GPU training
2. Combine with BF16 for maximum memory efficiency
3. Use `FULL_SHARD` for the largest models
4. Monitor GPU memory usage to verify sharding