Implement FSDP2Strategy
=======================
Overview
========
The **FSDP2Strategy** provides Fully Sharded Data Parallel (FSDP) training via PyTorch's FSDP2 API.
It enables distributed training with automatic model sharding and mixed-precision support.
Features
========
- Automatic model parallelism
- Mixed precision training
- Checkpoint management
- Deferred optimizer state restoration
- Device mesh initialization
Initialize
==========
To initialize the **FSDP2Strategy**, use the following arguments:

.. code-block:: python

    strategy = FSDP2Strategy(
        data_parallel_size="auto",
        tensor_parallel_size="auto",
        checkpoint_io=None,
        mp_policy=None,
        parallelize_fn=None,
        **kwargs,
    )

Arguments:
----------
- **data_parallel_size** (*Union["auto", int]*): Number of data-parallel replicas.
- **tensor_parallel_size** (*Union["auto", int]*): Number of tensor-parallel groups.
- **checkpoint_io** (*optional*): Checkpoint I/O handler.
- **mp_policy** (*optional*): Mixed precision policy.
- **parallelize_fn** (*callable, optional*): Model parallelization function.
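
To make the relationship between the two sizes concrete, here is a minimal, hypothetical sketch of how ``"auto"`` values could be resolved against the total number of ranks. ``resolve_parallel_sizes`` is not part of the API; the real strategy performs this resolution internally.

```python
def resolve_parallel_sizes(world_size, data_parallel_size="auto", tensor_parallel_size="auto"):
    """Resolve "auto" placeholders so that dp * tp == world_size.

    Hypothetical helper for illustration only.
    """
    if data_parallel_size == "auto" and tensor_parallel_size == "auto":
        # Nothing specified: default to pure data parallelism.
        data_parallel_size, tensor_parallel_size = world_size, 1
    elif data_parallel_size == "auto":
        data_parallel_size = world_size // tensor_parallel_size
    elif tensor_parallel_size == "auto":
        tensor_parallel_size = world_size // data_parallel_size
    if data_parallel_size * tensor_parallel_size != world_size:
        raise ValueError("data_parallel_size * tensor_parallel_size must equal world_size")
    return data_parallel_size, tensor_parallel_size

# 8 ranks with 2-way tensor parallelism -> 4 data-parallel replicas
print(resolve_parallel_sizes(8, tensor_parallel_size=2))  # (4, 2)
```
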
Parallelize
===========
The `parallelize()` method applies the sharding process to the model:

.. code-block:: python

    strategy.parallelize()

This method ensures that the model is only parallelized once.
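
The only-once guarantee can be sketched with a plain guard flag. This is a hypothetical stand-in class, not the strategy's actual implementation; the real method wraps the sharding work itself.

```python
class ShardedModelHolder:
    """Illustrative sketch of a parallelize-once guard."""

    def __init__(self):
        self._parallelized = False
        self.parallelize_calls = 0

    def parallelize(self):
        if self._parallelized:
            # Re-sharding an already sharded model would corrupt it,
            # so repeated calls are no-ops (or raise, depending on design).
            return
        self.parallelize_calls += 1  # stands in for the real sharding work
        self._parallelized = True

holder = ShardedModelHolder()
holder.parallelize()
holder.parallelize()
print(holder.parallelize_calls)  # 1
```
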
Environment Setup
=================
The `setup_environment()` method initializes the distributed environment and device mesh:

.. code-block:: python

    strategy.setup_environment()

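
For intuition about the device mesh, each rank gets one data-parallel and one tensor-parallel coordinate in a 2-D layout. Below is a hypothetical sketch of that rank layout; in practice PyTorch's ``init_device_mesh`` constructs the real mesh.

```python
def build_rank_mesh(world_size, tensor_parallel_size):
    """Sketch: arrange ranks into a (data_parallel, tensor_parallel) grid.

    Hypothetical helper; shown only to illustrate the mesh shape.
    """
    dp_size = world_size // tensor_parallel_size
    return [
        [dp * tensor_parallel_size + tp for tp in range(tensor_parallel_size)]
        for dp in range(dp_size)
    ]

# 8 ranks, 2-way tensor parallelism: 4 rows of 2 ranks each
print(build_rank_mesh(8, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```
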
Manage Checkpoints
==================
Save Checkpoints
----------------
The `save_checkpoint()` method unshards the checkpoint and saves it to disk:

.. code-block:: python

    strategy.save_checkpoint(checkpoint, filepath)

Load Checkpoints
----------------
The `load_checkpoint()` method loads a checkpoint from disk:

.. code-block:: python

    checkpoint = strategy.load_checkpoint(filepath)

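
The two methods form a simple roundtrip contract: whatever ``save_checkpoint`` writes, ``load_checkpoint`` must reproduce. A minimal sketch of that contract with plain ``pickle`` (the real implementation additionally unshards the model state before writing):

```python
import pickle
import tempfile
from pathlib import Path

def save_checkpoint(checkpoint, filepath):
    """Sketch: serialize the checkpoint dict to disk."""
    Path(filepath).write_bytes(pickle.dumps(checkpoint))

def load_checkpoint(filepath):
    """Sketch: read the checkpoint dict back from disk."""
    return pickle.loads(Path(filepath).read_bytes())

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.ckpt"
    save_checkpoint({"state_dict": {"w": [1.0, 2.0]}, "epoch": 3}, path)
    restored = load_checkpoint(path)
    print(restored["epoch"])  # 3
```
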
Restore Optimizer State
=======================
Optimizer state restoration is deferred until the first training step. Use the following method to restore the optimizer state:

.. code-block:: python

    strategy.load_optimizer_state_dict(checkpoint)

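
The deferred pattern can be sketched as "cache now, apply on first step". This is a hypothetical stand-in class illustrating the idea, assuming the strategy caches the checkpoint's optimizer states until the sharded parameters exist:

```python
class DeferredOptimizerRestore:
    """Sketch of deferred optimizer-state restoration."""

    def __init__(self):
        self._pending = None
        self.optimizer_state = {}

    def load_optimizer_state_dict(self, checkpoint):
        # Only cache the states; the sharded optimizer may not exist yet.
        self._pending = checkpoint.get("optimizer_states")

    def training_step(self, batch):
        if self._pending is not None:
            # First step: parameters are sharded now, safe to restore.
            self.optimizer_state = self._pending
            self._pending = None
        return sum(batch)  # stand-in for the real forward/backward

restorer = DeferredOptimizerRestore()
restorer.load_optimizer_state_dict({"optimizer_states": {"lr": 0.01}})
restorer.training_step([1, 2, 3])
print(restorer.optimizer_state)  # {'lr': 0.01}
```
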
Train and Evaluate the Model
============================
Training Step
-------------
The `training_step()` method defines a single training iteration:

.. code-block:: python

    loss = strategy.training_step(batch, batch_idx)

Validation Step
---------------
The `validation_step()` method defines a validation iteration:

.. code-block:: python

    loss = strategy.validation_step(batch, batch_idx)

Test Step
---------
The `test_step()` method defines a test iteration:

.. code-block:: python

    loss = strategy.test_step(batch, batch_idx)

Prediction Step
---------------
The `predict_step()` method defines a prediction iteration:

.. code-block:: python

    result = strategy.predict_step(batch, batch_idx)

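
All four hooks share the same shape: they take a batch and its index and return a per-batch result that an outer loop collects. A toy stand-in class (hypothetical; the real strategy forwards each call to the module's matching hook) shows that dispatch pattern:

```python
class TinyStrategy:
    """Toy stand-in illustrating the per-batch step hooks."""

    def training_step(self, batch, batch_idx):
        # Pretend "loss": mean of the batch values.
        return sum(batch) / len(batch)

    def predict_step(self, batch, batch_idx):
        # Pretend "prediction": double every value.
        return [x * 2 for x in batch]

strategy = TinyStrategy()
losses = [strategy.training_step(b, i) for i, b in enumerate([[1, 2], [3, 5]])]
print(losses)  # [1.5, 4.0]
```
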
Process DataLoader
==================
Use `process_dataloader()` to apply custom data sampling to a DataLoader:

.. code-block:: python

    dataloader = strategy.process_dataloader(dataloader)

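
The essential effect of distributed sampling is that each rank sees a disjoint slice of the dataset. A minimal sketch of that slicing, mirroring the interleaved index split used by ``torch.utils.data.DistributedSampler`` (``shard_indices`` is a hypothetical helper):

```python
def shard_indices(num_samples, world_size, rank):
    """Sketch: give each rank an interleaved, disjoint slice of indices."""
    return list(range(rank, num_samples, world_size))

# 10 samples across 4 ranks: every sample goes to exactly one rank
print(shard_indices(10, 4, 0))  # [0, 4, 8]
print(shard_indices(10, 4, 1))  # [1, 5, 9]
```
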
Retrieve State Dictionary
=========================
Retrieve the model's state dictionary using `lightning_module_state_dict()`:

.. code-block:: python

    state_dict = strategy.lightning_module_state_dict()

Remove Checkpoints
==================
Remove a checkpoint from the filesystem:

.. code-block:: python

    strategy.remove_checkpoint(filepath)

Initialize Tensors
==================
Use the `tensor_init_context()` context manager for tensor initialization:

.. code-block:: python

    with strategy.tensor_init_context():
        # Initialization code
        pass