# FSDP Integration Complete
Fully Sharded Data Parallel (FSDP) has been fully integrated into the training pipeline.
## What's Been Integrated
### 1. Service Functions
- **`fine_tune_da3()`** - FSDP support added
- **`pretrain_da3_on_arkit()`** - FSDP support added
### 2. API Endpoints
- **`/api/v1/train/start`** - FSDP parameters added to `TrainRequest`
- **`/api/v1/train/pretrain`** - FSDP parameters added to `PretrainRequest`
### 3. CLI Commands
- **`ylff train start`** - FSDP options added
- **`ylff train pretrain`** - FSDP options added
## New Parameters
### API Models
**TrainRequest & PretrainRequest**:
```python
use_fsdp: bool = False # Enable FSDP
fsdp_sharding_strategy: str = "FULL_SHARD" # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
fsdp_mixed_precision: Optional[str] = None # bf16, fp16, or None (auto-detects)
```
### CLI Options
```bash
--use-fsdp # Enable FSDP
--fsdp-sharding-strategy # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
--fsdp-mixed-precision # bf16, fp16, or None
```
## Usage
### CLI Example
```bash
# Multi-GPU training with FSDP
torchrun --nproc_per_node=4 $(which ylff) train start data/training \
--use-fsdp \
--fsdp-sharding-strategy FULL_SHARD \
--fsdp-mixed-precision bf16 \
--use-bf16 \
--batch-size 2
```
### API Example
```json
{
"training_data_dir": "data/training",
"epochs": 10,
"use_fsdp": true,
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_mixed_precision": "bf16",
"use_bf16": true
}
```
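The request above can be submitted programmatically. A minimal stdlib-only sketch, assuming the API server is reachable at `http://localhost:8000` (host and port are assumptions, not part of this repo):

```python
import json
import urllib.request

# Payload mirroring the JSON example above.
payload = {
    "training_data_dir": "data/training",
    "epochs": 10,
    "use_fsdp": True,
    "fsdp_sharding_strategy": "FULL_SHARD",
    "fsdp_mixed_precision": "bf16",
    "use_bf16": True,
}

req = urllib.request.Request(
    "http://localhost:8000/api/v1/train/start",  # assumed base URL
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# Uncomment against a running server:
# with urllib.request.urlopen(req) as resp:
#     print(resp.status, resp.read().decode())
```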
## How It Works
1. **Model Wrapping**: Before optimizer creation, the model is wrapped with FSDP if:
- `use_fsdp=True`
- Distributed training is initialized (`torch.distributed.is_initialized()`)
2. **Sharding Strategy**:
- **FULL_SHARD**: Shards parameters, gradients, and optimizer states (most memory efficient)
- **SHARD_GRAD_OP**: Shards only gradients and optimizer states
- **NO_SHARD**: No sharding (equivalent to DDP)
3. **Mixed Precision**: Auto-detects from `use_bf16` if not specified:
   - If `use_bf16=True` → uses `bf16`
   - If `use_bf16=False` → uses `None` (FP32)
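The auto-detection rule in step 3 can be sketched as a small resolver (the function name is illustrative; the actual helper in the services may differ):

```python
from typing import Optional

def resolve_mixed_precision(
    fsdp_mixed_precision: Optional[str], use_bf16: bool
) -> Optional[str]:
    """Explicit setting wins; otherwise fall back to use_bf16, else FP32 (None)."""
    if fsdp_mixed_precision is not None:
        return fsdp_mixed_precision
    return "bf16" if use_bf16 else None

print(resolve_mixed_precision(None, True))     # bf16
print(resolve_mixed_precision(None, False))    # None
print(resolve_mixed_precision("fp16", False))  # fp16 (explicit override)
```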
## Benefits
- **Memory Efficiency**: Train models 2-4x larger than single GPU memory
- **Scalability**: Per-GPU memory use drops as more GPUs are added, unlike DDP
- **Performance**: Similar speed to DDP with better memory utilization
- **Flexibility**: Works with existing optimizations (BF16, gradient clipping, etc.)
## Requirements
1. **PyTorch 2.0+** with FSDP support
2. **Distributed Training**: Must initialize distributed training first:
```bash
torchrun --nproc_per_node=N ...
```
Or manually initialize:
```python
import torch.distributed as dist
dist.init_process_group(...)
```
## Integration Points
### Service Functions
**Before optimizer creation**:
```python
if use_fsdp and dist.is_initialized():
    model = wrap_model_fsdp(
        model,
        sharding_strategy=fsdp_sharding_strategy,
        mixed_precision=fsdp_mixed_precision or ("bf16" if use_bf16 else None),
        device_id=torch.cuda.current_device() if device == "cuda" else None,
    )
```
### Checkpoint Saving/Loading
FSDP checkpoints are handled automatically via `fsdp_utils.py`:
- Uses `FullStateDictConfig` for saving
- Gathers full state dict on rank 0
- Shards optimizer state properly
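The rank-0 gathering described above maps onto PyTorch's FSDP state-dict API. A minimal sketch, assuming an already FSDP-wrapped model (the helper name is an assumption; the actual code in `fsdp_utils.py` may differ):

```python
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def gather_full_state_dict(model: FSDP) -> dict:
    """Gather a full (unsharded) state dict, materialized on rank 0 only."""
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        # Non-zero ranks receive an empty dict; rank 0 gets the full weights.
        return model.state_dict()
```

`offload_to_cpu=True` keeps the gathered copy off the GPU, which matters when the full model does not fit on a single device.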
## Files Modified
1. **`ylff/services/fine_tune.py`** - Added FSDP wrapping
2. **`ylff/services/pretrain.py`** - Added FSDP wrapping
3. **`ylff/models/api_models.py`** - Added FSDP fields to request models
4. **`ylff/routers/training.py`** - Pass FSDP params to service functions
5. **`ylff/cli.py`** - Added FSDP CLI options
## Next Steps
FSDP is fully integrated and ready to use! For best results:
1. Use with `torchrun` for multi-GPU training
2. Combine with BF16 for maximum memory efficiency
3. Use `FULL_SHARD` for largest models
4. Monitor GPU memory usage to verify sharding