Implement FSDP2Strategy
=======================

Overview
========

The **FSDP2Strategy** implements Fully Sharded Data Parallel (FSDP) via PyTorch's FSDP2 implementation.
It enables distributed training with automatic model sharding and mixed precision support.

Features
========

- Automatic model parallelism
- Mixed precision training
- Checkpoint management
- Deferred optimizer state restoration
- Device mesh initialization

Initialize
==========

To initialize the **FSDP2Strategy**, use the following arguments:

.. code-block:: python

    strategy = FSDP2Strategy(
        data_parallel_size="auto",
        tensor_parallel_size="auto",
        checkpoint_io=None,
        mp_policy=None,
        parallelize_fn=None,
        **kwargs,
    )
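Both parallel sizes default to ``"auto"``. One plausible way these placeholders resolve, assuming the product of the two sizes must equal the total number of ranks (the helper ``resolve_mesh_sizes`` below is a hypothetical illustration, not part of the strategy's API):

.. code-block:: python

    def resolve_mesh_sizes(world_size, data_parallel_size="auto", tensor_parallel_size="auto"):
        """Hypothetical helper: turn "auto" placeholders into concrete mesh sizes.

        Assumes world_size == data_parallel_size * tensor_parallel_size.
        """
        dp, tp = data_parallel_size, tensor_parallel_size
        if dp == "auto" and tp == "auto":
            # Assumption: fall back to pure data parallelism when nothing is pinned.
            dp, tp = world_size, 1
        elif dp == "auto":
            dp = world_size // tp
        elif tp == "auto":
            tp = world_size // dp
        if dp * tp != world_size:
            raise ValueError(f"mesh {dp} x {tp} does not cover world size {world_size}")
        return dp, tp

    print(resolve_mesh_sizes(8, tensor_parallel_size=2))  # (4, 2) on 8 ranks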

Arguments
---------

- **data_parallel_size** (*Union["auto", int]*): Number of data-parallel replicas.
- **tensor_parallel_size** (*Union["auto", int]*): Number of tensor-parallel groups.
- **checkpoint_io** (*optional*): Checkpoint I/O handler.
- **mp_policy** (*optional*): Mixed precision policy.
- **parallelize_fn** (*callable, optional*): Model parallelization function.
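As a fuller configuration sketch, the ``mp_policy`` argument can carry an FSDP2 mixed precision policy. The snippet below assumes a recent PyTorch release that exports ``MixedPrecisionPolicy`` from ``torch.distributed.fsdp``, and a job already launched across 8 ranks; treat it as a sketch under those assumptions, not a verified recipe:

.. code-block:: python

    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy

    # Keep parameters in bf16 for compute; accumulate gradient reductions in fp32.
    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    strategy = FSDP2Strategy(
        data_parallel_size=4,    # four data-parallel replicas
        tensor_parallel_size=2,  # two-way tensor parallelism (4 x 2 = 8 ranks)
        mp_policy=policy,
    )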

Parallelize
===========

The `parallelize()` method applies the sharding process to the model:

.. code-block:: python

    strategy.parallelize()

This method ensures that the model is only parallelized once.

Environment Setup
=================

The `setup_environment()` method initializes the distributed environment and device mesh:

.. code-block:: python

    strategy.setup_environment()

Manage Checkpoints
==================

Save Checkpoints
----------------

The `save_checkpoint()` method unshards the checkpoint and saves it to disk:

.. code-block:: python

    strategy.save_checkpoint(checkpoint, filepath)

Load Checkpoints
----------------

The `load_checkpoint()` method loads a checkpoint from disk:

.. code-block:: python

    checkpoint = strategy.load_checkpoint(filepath)

Restore Optimizer State
=======================

Optimizer state restoration is deferred until the first training step. Use the following method to restore the optimizer state:

.. code-block:: python

    strategy.load_optimizer_state_dict(checkpoint)

Train and Evaluate the Model
============================

Training Step
-------------

The `training_step()` method defines a single training iteration:

.. code-block:: python

    loss = strategy.training_step(batch, batch_idx)

Validation Step
---------------

The `validation_step()` method defines a validation iteration:

.. code-block:: python

    loss = strategy.validation_step(batch, batch_idx)

Test Step
---------

The `test_step()` method defines a test iteration:

.. code-block:: python

    loss = strategy.test_step(batch, batch_idx)

Prediction Step
---------------

The `predict_step()` method defines a prediction iteration:

.. code-block:: python

    result = strategy.predict_step(batch, batch_idx)

Process DataLoader
==================

Use `process_dataloader()` to apply custom data sampling to a DataLoader:

.. code-block:: python

    dataloader = strategy.process_dataloader(dataloader)

Retrieve State Dictionary
=========================

Retrieve the model's state dictionary using `lightning_module_state_dict()`:

.. code-block:: python

    state_dict = strategy.lightning_module_state_dict()

Remove Checkpoints
==================

Remove a checkpoint from the filesystem:

.. code-block:: python

    strategy.remove_checkpoint(filepath)

Initialize Tensors
==================

Use the `tensor_init_context()` context manager for tensor initialization:

.. code-block:: python

    with strategy.tensor_init_context():
        # Initialization code
        pass