# FSDP Integration Complete

Fully Sharded Data Parallel (FSDP) has been fully integrated into the training pipeline.

## βœ… What's Been Integrated

### 1. Service Functions

- **`fine_tune_da3()`** - FSDP support added
- **`pretrain_da3_on_arkit()`** - FSDP support added

### 2. API Endpoints

- **`/api/v1/train/start`** - FSDP parameters added to `TrainRequest`
- **`/api/v1/train/pretrain`** - FSDP parameters added to `PretrainRequest`

### 3. CLI Commands

- **`ylff train start`** - FSDP options added
- **`ylff train pretrain`** - FSDP options added

## πŸ“‹ New Parameters

### API Models

**TrainRequest & PretrainRequest**:

```python
use_fsdp: bool = False  # Enable FSDP
fsdp_sharding_strategy: str = "FULL_SHARD"  # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
fsdp_mixed_precision: Optional[str] = None  # bf16, fp16, or None (auto-detects)
```
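As a rough illustration of how these fields can be validated (using a plain dataclass stand-in, not the project's actual Pydantic request models):

```python
from dataclasses import dataclass
from typing import Optional

VALID_STRATEGIES = {"FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"}
VALID_PRECISIONS = {"bf16", "fp16", None}

@dataclass
class FSDPOptions:
    """Hypothetical stand-in for the FSDP fields on TrainRequest/PretrainRequest."""
    use_fsdp: bool = False
    fsdp_sharding_strategy: str = "FULL_SHARD"
    fsdp_mixed_precision: Optional[str] = None

    def __post_init__(self) -> None:
        # Reject values outside the documented sets up front.
        if self.fsdp_sharding_strategy not in VALID_STRATEGIES:
            raise ValueError(f"unknown sharding strategy: {self.fsdp_sharding_strategy!r}")
        if self.fsdp_mixed_precision not in VALID_PRECISIONS:
            raise ValueError(f"unknown mixed precision: {self.fsdp_mixed_precision!r}")
```

In the real API models the same constraints would typically be expressed as Pydantic validators, so invalid requests are rejected with a 422 before reaching the training service.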

### CLI Options

```bash
--use-fsdp                    # Enable FSDP
--fsdp-sharding-strategy      # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
--fsdp-mixed-precision        # bf16, fp16, or None
```

## πŸš€ Usage

### CLI Example

```bash
# Multi-GPU training with FSDP
torchrun --nproc_per_node=4 ylff train start data/training \
    --use-fsdp \
    --fsdp-sharding-strategy FULL_SHARD \
    --fsdp-mixed-precision bf16 \
    --use-bf16 \
    --batch-size 2
```

### API Example

```json
{
  "training_data_dir": "data/training",
  "epochs": 10,
  "use_fsdp": true,
  "fsdp_sharding_strategy": "FULL_SHARD",
  "fsdp_mixed_precision": "bf16",
  "use_bf16": true
}
```

## πŸ”§ How It Works

1. **Model Wrapping**: Before optimizer creation, the model is wrapped with FSDP if:

   - `use_fsdp=True`
   - Distributed training is initialized (`torch.distributed.is_initialized()`)

2. **Sharding Strategy**:

   - **FULL_SHARD**: Shards parameters, gradients, and optimizer states (most memory efficient)
   - **SHARD_GRAD_OP**: Shards only gradients and optimizer states
   - **NO_SHARD**: No sharding (equivalent to DDP)

3. **Mixed Precision**: Auto-detects from `use_bf16` if not specified:
   - If `use_bf16=True` β†’ uses `bf16`
   - If `use_bf16=False` β†’ uses `None` (FP32)
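The auto-detection above amounts to a one-line fallback. A minimal sketch (`resolve_mixed_precision` is an illustrative name, not necessarily the project's actual helper):

```python
from typing import Optional

def resolve_mixed_precision(fsdp_mixed_precision: Optional[str], use_bf16: bool) -> Optional[str]:
    """Return the explicit FSDP precision if given; otherwise fall back to use_bf16."""
    if fsdp_mixed_precision is not None:
        return fsdp_mixed_precision  # explicit "bf16" or "fp16" wins
    return "bf16" if use_bf16 else None  # None means full FP32
```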

## πŸ“Š Benefits

- **Memory Efficiency**: Train models 2-4x larger than would fit in a single GPU's memory
- **Scalability**: Per-GPU memory footprint shrinks as GPUs are added, unlike DDP, which keeps a full model replica on every device
- **Performance**: Similar speed to DDP with better memory utilization
- **Flexibility**: Works with existing optimizations (BF16, gradient clipping, etc.)

## ⚠️ Requirements

1. **PyTorch 2.0+** with FSDP support
2. **Distributed Training**: Must initialize distributed training first:
   ```bash
   torchrun --nproc_per_node=N ...
   ```
   Or manually initialize:
   ```python
   import torch.distributed as dist
   dist.init_process_group(...)
   ```

## πŸ”„ Integration Points

### Service Functions

**Before optimizer creation**:

```python
if use_fsdp:
    if dist.is_initialized():
        model = wrap_model_fsdp(
            model,
            sharding_strategy=fsdp_sharding_strategy,
            mixed_precision=fsdp_mixed_precision or ("bf16" if use_bf16 else None),
            device_id=torch.cuda.current_device() if device == "cuda" else None,
        )
```
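`wrap_model_fsdp` lives in project code, but a minimal sketch of such a wrapper on top of PyTorch's FSDP API could look like this (an assumption about its internals, not the actual implementation; imports are local so the module loads without torch):

```python
VALID_STRATEGIES = ("FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD")

def wrap_model_fsdp(model, sharding_strategy="FULL_SHARD", mixed_precision=None, device_id=None):
    """Hypothetical sketch of an FSDP wrapper; details may differ from fsdp_utils.py."""
    if sharding_strategy not in VALID_STRATEGIES:
        raise ValueError(f"unknown sharding strategy: {sharding_strategy!r}")

    # Local imports keep this snippet importable without torch installed.
    import torch
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        MixedPrecision,
        ShardingStrategy,
    )

    # Map the string setting ("bf16"/"fp16") onto an FSDP MixedPrecision policy.
    mp_policy = None
    if mixed_precision is not None:
        dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
        mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

    return FSDP(
        model,
        sharding_strategy=getattr(ShardingStrategy, sharding_strategy),
        mixed_precision=mp_policy,
        device_id=device_id,
    )
```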

### Checkpoint Saving/Loading

FSDP checkpoints are handled automatically via `fsdp_utils.py`:

- Uses `FullStateDictConfig` for saving
- Gathers full state dict on rank 0
- Shards optimizer state properly

## πŸ“ Files Modified

1. **`ylff/services/fine_tune.py`** - Added FSDP wrapping
2. **`ylff/services/pretrain.py`** - Added FSDP wrapping
3. **`ylff/models/api_models.py`** - Added FSDP fields to request models
4. **`ylff/routers/training.py`** - Pass FSDP params to service functions
5. **`ylff/cli.py`** - Added FSDP CLI options

## 🎯 Next Steps

FSDP is fully integrated and ready to use! For best results:

1. Use with `torchrun` for multi-GPU training
2. Combine with BF16 for maximum memory efficiency
3. Use `FULL_SHARD` for largest models
4. Monitor GPU memory usage to verify sharding