Implement FSDP2Strategy
=======================

Overview
========

The **FSDP2Strategy** implements Fully Sharded Data Parallel (FSDP) training via PyTorch's FSDP2 implementation. It enables distributed training with automatic model sharding and mixed-precision support.

Features
========

- Automatic model parallelism
- Mixed-precision training
- Checkpoint management
- Deferred optimizer state restoration
- Device mesh initialization

Initialize
==========

To initialize the **FSDP2Strategy**, use the following arguments:

.. code-block:: python

    strategy = FSDP2Strategy(
        data_parallel_size="auto",
        tensor_parallel_size="auto",
        checkpoint_io=None,
        mp_policy=None,
        parallelize_fn=None,
        **kwargs,
    )

Arguments:
----------

- **data_parallel_size** (*Union["auto", int]*): Number of data-parallel replicas.
- **tensor_parallel_size** (*Union["auto", int]*): Number of tensor-parallel groups.
- **checkpoint_io** (*optional*): Checkpoint I/O handler.
- **mp_policy** (*optional*): Mixed-precision policy.
- **parallelize_fn** (*callable, optional*): Function that applies parallelism to the model.

Parallelize
===========

The `parallelize()` method applies the sharding process to the model:

.. code-block:: python

    strategy.parallelize()

This method ensures that the model is parallelized only once.

Environment Setup
=================

The `setup_environment()` method initializes the distributed environment and the device mesh:

.. code-block:: python

    strategy.setup_environment()
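
As a rough illustration of what a 2-D device mesh looks like, the sketch below lays out world ranks on a (data-parallel, tensor-parallel) grid. The contiguous tensor-parallel layout is a common convention, shown here as an assumption rather than the strategy's documented behavior:

```python
def build_mesh(world_size: int, tensor_parallel_size: int) -> list[list[int]]:
    """Arrange ranks into a (dp, tp) grid; each row is one data-parallel replica."""
    assert world_size % tensor_parallel_size == 0, "world size must divide evenly"
    data_parallel_size = world_size // tensor_parallel_size
    return [
        [dp * tensor_parallel_size + tp for tp in range(tensor_parallel_size)]
        for dp in range(data_parallel_size)
    ]


mesh = build_mesh(world_size=8, tensor_parallel_size=2)
print(mesh)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```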

Manage Checkpoints
==================

Save Checkpoints
----------------

The `save_checkpoint()` method unshards the checkpoint and saves it to disk:

.. code-block:: python

    strategy.save_checkpoint(checkpoint, filepath)
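
Conceptually, "unshard then save" means reassembling full parameters from each rank's shard before writing a single file. A minimal sketch with plain lists standing in for tensor chunks (the real strategy operates on sharded ``torch`` state dicts):

```python
import pickle
import tempfile


def unshard(shards):
    """Concatenate each parameter's shards (lists stand in for tensor chunks)."""
    full = {}
    for shard in shards:
        for name, chunk in shard.items():
            full.setdefault(name, []).extend(chunk)
    return full


def save_unsharded_checkpoint(shards, filepath):
    # Reassemble the full state dict, then write it to disk once.
    with open(filepath, "wb") as f:
        pickle.dump(unshard(shards), f)


shards = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]  # two ranks, one shard each
with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp:
    path = tmp.name
save_unsharded_checkpoint(shards, path)
with open(path, "rb") as f:
    print(pickle.load(f))  # {'w': [1.0, 2.0, 3.0, 4.0]}
```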

Load Checkpoints
----------------

The `load_checkpoint()` method loads a checkpoint from disk:

.. code-block:: python

    checkpoint = strategy.load_checkpoint(filepath)

Restore Optimizer State
=======================

Restoration of the optimizer state is deferred until the first training step. Use the following method to restore the optimizer state:

.. code-block:: python

    strategy.load_optimizer_state_dict(checkpoint)

Train and Evaluate the Model
============================

Training Step
-------------

The `training_step()` method defines a single training iteration:

.. code-block:: python

    loss = strategy.training_step(batch, batch_idx)

Validation Step
---------------

The `validation_step()` method defines a single validation iteration:

.. code-block:: python

    loss = strategy.validation_step(batch, batch_idx)

Test Step
---------

The `test_step()` method defines a single test iteration:

.. code-block:: python

    loss = strategy.test_step(batch, batch_idx)

Prediction Step
---------------

The `predict_step()` method defines a single prediction iteration:

.. code-block:: python

    result = strategy.predict_step(batch, batch_idx)

Process DataLoader
==================

Use `process_dataloader()` to apply custom data sampling to a DataLoader:

.. code-block:: python

    dataloader = strategy.process_dataloader(dataloader)
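
To illustrate what distributed sampling accomplishes, the sketch below shards sample indices round-robin across ranks so that each data-parallel replica sees a disjoint slice of the dataset. This is a simplification of what a ``DistributedSampler`` provides, not the strategy's exact logic:

```python
def shard_indices(dataset_len: int, rank: int, world_size: int) -> list[int]:
    """Round-robin assignment of sample indices to one rank."""
    return list(range(rank, dataset_len, world_size))


# With 8 samples and 2 ranks, each rank sees a disjoint half of the data.
print(shard_indices(8, rank=0, world_size=2))  # [0, 2, 4, 6]
print(shard_indices(8, rank=1, world_size=2))  # [1, 3, 5, 7]
```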

Retrieve State Dictionary
=========================

Retrieve the model's state dictionary with `lightning_module_state_dict()`:

.. code-block:: python

    state_dict = strategy.lightning_module_state_dict()

Remove Checkpoints
==================

Use `remove_checkpoint()` to remove a checkpoint from the filesystem:

.. code-block:: python

    strategy.remove_checkpoint(filepath)

Initialize Tensors
==================

Use the `tensor_init_context()` context manager for tensor initialization:

.. code-block:: python

    with strategy.tensor_init_context():
        # Initialization code
        pass