# FSDP Integration Complete
Fully Sharded Data Parallel (FSDP) has been integrated into the training pipeline.
## ✅ What's Been Integrated
### 1. Service Functions

- `fine_tune_da3()` - FSDP support added
- `pretrain_da3_on_arkit()` - FSDP support added
### 2. API Endpoints

- `/api/v1/train/start` - FSDP parameters added to `TrainRequest`
- `/api/v1/train/pretrain` - FSDP parameters added to `PretrainRequest`
### 3. CLI Commands

- `ylff train start` - FSDP options added
- `ylff train pretrain` - FSDP options added
## 📋 New Parameters
### API Models

`TrainRequest` & `PretrainRequest`:

```python
use_fsdp: bool = False                      # Enable FSDP
fsdp_sharding_strategy: str = "FULL_SHARD"  # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
fsdp_mixed_precision: Optional[str] = None  # bf16, fp16, or None (auto-detects)
```
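For context, a minimal sketch of where these fields sit, assuming the request models are Pydantic `BaseModel`s (the real models carry many more fields, and the defaults shown here are illustrative):

```python
from typing import Optional

from pydantic import BaseModel


class TrainRequest(BaseModel):
    # Sketch only: existing fields omitted, FSDP fields as documented above.
    training_data_dir: str
    use_fsdp: bool = False
    fsdp_sharding_strategy: str = "FULL_SHARD"
    fsdp_mixed_precision: Optional[str] = None
    use_bf16: bool = False
```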
### CLI Options

```text
--use-fsdp                  # Enable FSDP
--fsdp-sharding-strategy    # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
--fsdp-mixed-precision      # bf16, fp16, or None
```
## 🚀 Usage

### CLI Example

```bash
# Multi-GPU training with FSDP
torchrun --nproc_per_node=4 ylff train start data/training \
  --use-fsdp \
  --fsdp-sharding-strategy FULL_SHARD \
  --fsdp-mixed-precision bf16 \
  --use-bf16 \
  --batch-size 2
```
### API Example

```json
{
  "training_data_dir": "data/training",
  "epochs": 10,
  "use_fsdp": true,
  "fsdp_sharding_strategy": "FULL_SHARD",
  "fsdp_mixed_precision": "bf16",
  "use_bf16": true
}
```
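To exercise the endpoint from Python, something like the following should work; the host and port are placeholders for wherever the ylff server is running:

```python
import requests  # third-party: pip install requests

payload = {
    "training_data_dir": "data/training",
    "epochs": 10,
    "use_fsdp": True,
    "fsdp_sharding_strategy": "FULL_SHARD",
    "fsdp_mixed_precision": "bf16",
    "use_bf16": True,
}

# http://localhost:8000 is an assumed address; adjust for your deployment.
resp = requests.post("http://localhost:8000/api/v1/train/start", json=payload)
resp.raise_for_status()
print(resp.json())
```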
## 🔧 How It Works
**Model Wrapping**: Before optimizer creation, the model is wrapped with FSDP if:

- `use_fsdp=True`
- Distributed training is initialized (`torch.distributed.is_initialized()`)
**Sharding Strategy**:

- `FULL_SHARD`: Shards parameters, gradients, and optimizer states (most memory efficient)
- `SHARD_GRAD_OP`: Shards only gradients and optimizer states
- `NO_SHARD`: No sharding (equivalent to DDP)
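These strings presumably map onto PyTorch's `ShardingStrategy` enum; a minimal lookup table would look like:

```python
from torch.distributed.fsdp import ShardingStrategy

# Maps the string options accepted by the API/CLI onto PyTorch's enum.
STRATEGY_MAP = {
    "FULL_SHARD": ShardingStrategy.FULL_SHARD,
    "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
    "NO_SHARD": ShardingStrategy.NO_SHARD,
}
```

(Equivalently, `ShardingStrategy[value]` resolves the name directly.)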
**Mixed Precision**: Auto-detects from `use_bf16` if not specified:

- If `use_bf16=True` → uses `bf16`
- If `use_bf16=False` → uses `None` (FP32)
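The resolution rule is simple enough to state in code; `resolve_mixed_precision` below is a hypothetical helper, not the actual name in the codebase:

```python
from typing import Optional


def resolve_mixed_precision(fsdp_mixed_precision: Optional[str],
                            use_bf16: bool) -> Optional[str]:
    # An explicit setting ("bf16" or "fp16") always wins.
    if fsdp_mixed_precision is not None:
        return fsdp_mixed_precision
    # Otherwise fall back to the use_bf16 flag: bf16, or None for FP32.
    return "bf16" if use_bf16 else None
```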
## 📈 Benefits
- **Memory Efficiency**: Train models 2-4x larger than a single GPU's memory would otherwise allow
- **Scalability**: Better memory scaling than DDP as models grow
- **Performance**: Throughput comparable to DDP, with better memory utilization
- **Flexibility**: Works with existing optimizations (BF16, gradient clipping, etc.)
## ⚠️ Requirements

- **PyTorch 2.0+** with FSDP support
- **Distributed Training**: Must initialize distributed training first:

```bash
torchrun --nproc_per_node=N ...
```

Or manually initialize:

```python
import torch.distributed as dist

dist.init_process_group(...)
```
## 🔗 Integration Points

### Service Functions

Before optimizer creation:
```python
import torch
import torch.distributed as dist

# Wrap the model before the optimizer is created, so the optimizer
# sees FSDP's flattened, sharded parameters.
if use_fsdp:
    if dist.is_initialized():
        model = wrap_model_fsdp(
            model,
            sharding_strategy=fsdp_sharding_strategy,
            mixed_precision=fsdp_mixed_precision or ("bf16" if use_bf16 else None),
            device_id=torch.cuda.current_device() if device == "cuda" else None,
        )
```
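`wrap_model_fsdp` lives in `fsdp_utils.py`; a plausible sketch of what it does, assuming it is a thin wrapper over `torch.distributed.fsdp.FullyShardedDataParallel` (the real implementation may configure more, e.g. an auto-wrap policy):

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)


def wrap_model_fsdp(model, sharding_strategy="FULL_SHARD",
                    mixed_precision=None, device_id=None):
    # Build a MixedPrecision policy only when a dtype was requested.
    mp_policy = None
    if mixed_precision is not None:
        dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
        mp_policy = MixedPrecision(
            param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
        )
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy[sharding_strategy],
        mixed_precision=mp_policy,
        device_id=device_id,
    )
```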
### Checkpoint Saving/Loading

FSDP checkpoints are handled automatically via `fsdp_utils.py`:

- Uses `FullStateDictConfig` for saving
- Gathers the full state dict on rank 0
- Shards optimizer state properly
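For reference, the standard PyTorch pattern for rank-0 full-state-dict saving, which `fsdp_utils.py` presumably follows (`save_path` is a placeholder):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

# Gather a full (unsharded) state dict, offloaded to CPU, on rank 0 only.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    state_dict = model.state_dict()

if dist.get_rank() == 0:
    torch.save(state_dict, save_path)  # save_path: wherever checkpoints go
```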
## 📁 Files Modified

- `ylff/services/fine_tune.py` - Added FSDP wrapping
- `ylff/services/pretrain.py` - Added FSDP wrapping
- `ylff/models/api_models.py` - Added FSDP fields to request models
- `ylff/routers/training.py` - Pass FSDP params to service functions
- `ylff/cli.py` - Added FSDP CLI options
## 🎯 Next Steps
FSDP is fully integrated and ready to use! For best results:
- Use with `torchrun` for multi-GPU training
- Combine with BF16 for maximum memory efficiency
- Use `FULL_SHARD` for the largest models
- Monitor GPU memory usage to verify sharding