# FSDP Integration Complete
Fully Sharded Data Parallel (FSDP) has been fully integrated into the training pipeline.
## βœ… What's Been Integrated
### 1. Service Functions
- **`fine_tune_da3()`** - FSDP support added
- **`pretrain_da3_on_arkit()`** - FSDP support added
### 2. API Endpoints
- **`/api/v1/train/start`** - FSDP parameters added to `TrainRequest`
- **`/api/v1/train/pretrain`** - FSDP parameters added to `PretrainRequest`
### 3. CLI Commands
- **`ylff train start`** - FSDP options added
- **`ylff train pretrain`** - FSDP options added
## πŸ“‹ New Parameters
### API Models
**TrainRequest & PretrainRequest**:
```python
use_fsdp: bool = False # Enable FSDP
fsdp_sharding_strategy: str = "FULL_SHARD" # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
fsdp_mixed_precision: Optional[str] = None # bf16, fp16, or None (auto-detects)
```
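The request fields above can be sketched as a plain dataclass. This is a hypothetical stand-in for the actual request models in `ylff/models/api_models.py` (which presumably use Pydantic); the `epochs` default is illustrative, not taken from the source:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for the real TrainRequest model; the FSDP field
# names and defaults mirror the parameters listed above.
@dataclass
class TrainRequest:
    training_data_dir: str
    epochs: int = 10                             # illustrative default
    use_fsdp: bool = False                       # enable FSDP
    fsdp_sharding_strategy: str = "FULL_SHARD"   # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
    fsdp_mixed_precision: Optional[str] = None   # "bf16", "fp16", or None (auto-detect)
    use_bf16: bool = False

req = TrainRequest(training_data_dir="data/training", use_fsdp=True)
```

With this shape, omitting the FSDP fields keeps existing callers working unchanged, since every new field has a backward-compatible default.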
### CLI Options
```bash
--use-fsdp # Enable FSDP
--fsdp-sharding-strategy # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
--fsdp-mixed-precision # bf16, fp16, or None
```
## πŸš€ Usage
### CLI Example
```bash
# Multi-GPU training with FSDP
torchrun --nproc_per_node=4 ylff train start data/training \
--use-fsdp \
--fsdp-sharding-strategy FULL_SHARD \
--fsdp-mixed-precision bf16 \
--use-bf16 \
--batch-size 2
```
### API Example
```json
{
"training_data_dir": "data/training",
"epochs": 10,
"use_fsdp": true,
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_mixed_precision": "bf16",
"use_bf16": true
}
```
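A request like the one above can be assembled and submitted with the standard library alone. The helper name, host, and port here are hypothetical; only the endpoint path and field names come from this document:

```python
import json
import urllib.request

def build_train_payload(data_dir, epochs=10, use_fsdp=True,
                        sharding="FULL_SHARD", precision="bf16", use_bf16=True):
    """Assemble the JSON body for /api/v1/train/start (hypothetical helper)."""
    return {
        "training_data_dir": data_dir,
        "epochs": epochs,
        "use_fsdp": use_fsdp,
        "fsdp_sharding_strategy": sharding,
        "fsdp_mixed_precision": precision,
        "use_bf16": use_bf16,
    }

payload = build_train_payload("data/training")
body = json.dumps(payload).encode()

# Hypothetical host/port; uncomment to actually submit the job.
# req = urllib.request.Request(
#     "http://localhost:8000/api/v1/train/start",
#     data=body,
#     headers={"Content-Type": "application/json"},
# )
# urllib.request.urlopen(req)
```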
## πŸ”§ How It Works
1. **Model Wrapping**: Before optimizer creation, the model is wrapped with FSDP if:
- `use_fsdp=True`
- Distributed training is initialized (`torch.distributed.is_initialized()`)
2. **Sharding Strategy**:
- **FULL_SHARD**: Shards parameters, gradients, and optimizer states (most memory efficient)
- **SHARD_GRAD_OP**: Shards only gradients and optimizer states
- **NO_SHARD**: No sharding (equivalent to DDP)
3. **Mixed Precision**: Auto-detects from `use_bf16` if not specified:
- If `use_bf16=True` β†’ uses `bf16`
- If `use_bf16=False` β†’ uses `None` (FP32)
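The strategy validation and mixed-precision auto-detection described above can be sketched as a small pure-Python helper. The function name is hypothetical (the real wiring lives inside the service functions), but the rules it encodes are exactly the ones listed:

```python
from typing import Optional, Tuple

# The three strategies accepted by the API and CLI, per the docs above.
_VALID_STRATEGIES = {"FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"}

def resolve_fsdp_config(sharding_strategy: str,
                        mixed_precision: Optional[str],
                        use_bf16: bool) -> Tuple[str, Optional[str]]:
    """Validate the sharding strategy and apply the auto-detection rule:
    when mixed_precision is None, fall back to bf16 iff use_bf16 is set."""
    if sharding_strategy not in _VALID_STRATEGIES:
        raise ValueError(f"unknown sharding strategy: {sharding_strategy!r}")
    if mixed_precision is None:
        mixed_precision = "bf16" if use_bf16 else None  # None means full FP32
    return sharding_strategy, mixed_precision
```

For example, `resolve_fsdp_config("FULL_SHARD", None, True)` yields `("FULL_SHARD", "bf16")`, while an explicit `"fp16"` is passed through untouched regardless of `use_bf16`.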
## πŸ“Š Benefits
- **Memory Efficiency**: Train models 2-4x larger than a single GPU's memory would otherwise allow
- **Scalability**: Per-GPU memory use shrinks as more GPUs are added, unlike DDP, which replicates the full model on every device
- **Performance**: Similar speed to DDP with better memory utilization
- **Flexibility**: Works with existing optimizations (BF16, gradient clipping, etc.)
## ⚠️ Requirements
1. **PyTorch 2.0+** with FSDP support
2. **Distributed Training**: Must initialize distributed training first:
```bash
torchrun --nproc_per_node=N ...
```
Or manually initialize:
```python
import torch.distributed as dist
dist.init_process_group(...)
```
## πŸ”„ Integration Points
### Service Functions
**Before optimizer creation**:
```python
if use_fsdp:
    if dist.is_initialized():
        model = wrap_model_fsdp(
            model,
            sharding_strategy=fsdp_sharding_strategy,
            mixed_precision=fsdp_mixed_precision or ("bf16" if use_bf16 else None),
            device_id=torch.cuda.current_device() if device == "cuda" else None,
        )
```
### Checkpoint Saving/Loading
FSDP checkpoints are handled automatically via `fsdp_utils.py`:
- Uses `FullStateDictConfig` for saving
- Gathers full state dict on rank 0
- Saves and restores optimizer state in a shard-aware way
## πŸ“ Files Modified
1. **`ylff/services/fine_tune.py`** - Added FSDP wrapping
2. **`ylff/services/pretrain.py`** - Added FSDP wrapping
3. **`ylff/models/api_models.py`** - Added FSDP fields to request models
4. **`ylff/routers/training.py`** - Pass FSDP params to service functions
5. **`ylff/cli.py`** - Added FSDP CLI options
## 🎯 Next Steps
FSDP is fully integrated and ready to use! For best results:
1. Use with `torchrun` for multi-GPU training
2. Combine with BF16 for maximum memory efficiency
3. Use `FULL_SHARD` for largest models
4. Monitor GPU memory usage to verify sharding