| | """ |
| | Wrapper around FSDP for more convenient use in the training loops. |
| | """ |
| |
|
| | from contextlib import contextmanager |
| | import typing as tp |
| | import dora |
| | import torch |
| |
|
| | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | from torch.distributed.fsdp import ( |
| | MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType) |
| | from torch.distributed._shard.sharded_tensor.api import ShardedTensor |
| |
|
| |
|
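# A minimal usage sketch (the surrounding training-loop code is an assumption, not part of
# this module; `cfg.fsdp` stands for whatever config object the caller passes in, and only
# needs the fields read by `wrap_with_fsdp` below: `use`, `param_dtype`, `reduce_dtype`,
# `buffer_dtype`, `sharding_strategy` and `per_block`):
#
#     if is_fsdp_used():
#         model = wrap_with_fsdp(cfg.fsdp, model)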


def is_fsdp_used() -> bool:
    """Return whether we are using FSDP."""
    # Rely on the config of the currently running Dora XP, if any.
    if dora.is_xp():
        cfg = dora.get_xp().cfg
        if hasattr(cfg, 'fsdp'):
            return cfg.fsdp.use
    return False


def is_sharded_tensor(x: tp.Any) -> bool:
    """Return whether ``x`` is a PyTorch ``ShardedTensor``."""
    return isinstance(x, ShardedTensor)


@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    """Context manager that temporarily switches the given FSDP models to
    ``FULL_STATE_DICT`` (consolidated on rank 0 and offloaded to CPU), and
    restores ``LOCAL_STATE_DICT`` on exit.
    """
    for model in models:
        FSDP.set_state_dict_type(
            model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
    try:
        yield
    finally:
        for model in models:
            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)
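# A minimal sketch of saving a consolidated checkpoint with the context manager above
# (the surrounding training-loop code is an assumption, not part of this module): with
# `rank0_only=True`, only rank 0 receives the full state dict and should write it out.
#
#     with switch_to_full_state_dict([model]):
#         full_state = model.state_dict()
#     if torch.distributed.get_rank() == 0:
#         torch.save(full_state, 'consolidated.th')  # hypothetical file name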


def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP according to ``cfg``. When ``cfg.per_block`` is set,
    the classes in ``block_classes`` (by default the streaming transformer layer and
    the conditioning provider) are wrapped as individual FSDP units.
    """
    # Deferred import of the auto-wrap policy helper.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # Imported here to avoid a circular import with the modules package.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()
    assert cfg.use
    sharding_strategy_dict = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[cfg.param_dtype],
        reduce_dtype=dtype_dict[cfg.reduce_dtype],
        buffer_dtype=dtype_dict[cfg.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
    # FULL_SHARD would require more careful handling when gathering the state dict
    # across ranks, hence the restriction below.
    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"
    auto_wrap_policy = None
    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    if cfg.per_block:
        auto_wrap_policy = ModuleWrapPolicy(block_classes)
    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)

    # Let each wrapped module know about its FSDP wrapper. The reference is stored
    # directly in `__dict__` (bypassing `nn.Module.__setattr__`) so that it is not
    # registered as a submodule and does not end up in the state dict.
    for module in FSDP.fsdp_modules(wrapped):
        original = module._fsdp_wrapped_module
        original.__dict__['_fsdp'] = module
    return wrapped
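# A sketch of why the `_fsdp` reference matters (the call site below is hypothetical, not
# part of this module): code holding the inner module should call through the FSDP wrapper
# so parameters are gathered before the forward pass, rather than calling the bare module.
#
#     fsdp_or_module = inner_module.__dict__.get('_fsdp', inner_module)
#     output = fsdp_or_module(inputs)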


def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard
    for module in FSDP.fsdp_modules(model):
        handles = module._handles
        if not handles:
            continue
        handle = handles[0]
        unsharded_flat_param = handle._get_padded_unsharded_flat_param()
        storage_size: int = unsharded_flat_param._typed_storage()._size()
        # If the flat parameter is not currently gathered, there is nothing to free.
        if storage_size == 0:
            continue
        # Reshard all handles, freeing the cached unsharded (gathered) flat parameters.
        true_list = [True for h in handles]
        _reshard(module, handles, true_list)
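# A minimal usage sketch (the EMA bookkeeping below is an assumption about the caller,
# not part of this module): after overwriting the model parameters in place, purge the
# cached unsharded weights so the next forward pass re-gathers the new values.
#
#     with torch.no_grad():
#         for p, ema_p in zip(model.parameters(), ema_params):  # `ema_params` is hypothetical
#             p.copy_(ema_p)
#     purge_fsdp(model)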


class _FSDPFixStateDict(FSDP):
    """FSDP subclass working around a few state dict issues: ``state_dict`` drops
    ``ShardedTensor`` entries, and ``load_state_dict`` supports loading a local
    (sharded) state dict by copying values into the current local state.
    """
    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        # Strip the wrapper components that FSDP inserts into parameter names.
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE
        parts = name.split('.')
        new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
        return '.'.join(new_parts)

    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:
        state = dict(super().state_dict(*args, **kwargs))
        # Drop ShardedTensor entries from the returned state.
        for key, value in list(state.items()):
            if is_sharded_tensor(value):
                del state[key]
        return state

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
            purge_fsdp(self)
            return
        # Otherwise we assume a LOCAL_STATE_DICT: copy the provided values into the
        # current local state, entry by entry.
        current_state = dict(super().state_dict())
        for key, value in state.items():
            key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
            if key not in current_state:
                # Emulate strict loading manually.
                raise RuntimeError(f"Unknown state key {key}")
            current_state[key].copy_(value)

        # Purge any cached unsharded weights left over from a previous forward pass.
        purge_fsdp(self)


# Whether the post backward hook has already been patched for this process.
_hook_fixed = False


def _fix_post_backward_hook():
    """Monkey-patch FSDP's post backward hook so that it tolerates the extra forward
    pass introduced by activation checkpointing. Applied at most once per process.
    """
    global _hook_fixed
    if _hook_fixed:
        return
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
    old_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False)
        if checkpointed:
            # With activation checkpointing, an extra forward pass happens during the
            # backward pass, which confuses FSDP's internal bookkeeping. Pretend that
            # everything is going according to plan so the original hook does not complain.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        old_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook