| from functools import partial |
| import os |
| from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP, |
| FullStateDictConfig, StateDictType) |
|
|
| from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy, |
| transformer_auto_wrap_policy) |
| from wenet.LLM.decoder import DecoderOnly |
| from wenet.branchformer.encoder_layer import BranchformerEncoderLayer |
| from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer |
| from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer |
| from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer |
| from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer |
| from wenet.transformer.encoder_layer import (ConformerEncoderLayer, |
| TransformerEncoderLayer) |
| from wenet.transformer.decoder_layer import DecoderLayer |
| from wenet.utils.checkpoint import save_state_dict_and_infos |
| from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES |
|
|
| WENET_ENCODER_LAYERS_CLASSES = { |
| 'transformer_encoder_layer': TransformerEncoderLayer, |
| 'conformer_encoder_layer': ConformerEncoderLayer, |
| 'paraformer_encoder_layer': AliParaformerEncoderLayer, |
| 'squeezeformer_encoder_layer': SqueezeformerEncoderLayer, |
| 'ebranchformer_encoder_layer': EBranchformerEncoderLayer, |
| 'efficient_conformer_encoder_layer': StrideConformerEncoderLayer, |
| 'branchformer_encoder_layer': BranchformerEncoderLayer, |
| } |
|
|
| WENET_DECODER_LAYERS_CLASSES = { |
| 'transformer_decoder_layer': DecoderLayer, |
| 'paraformer_decoder_layer': SanmDecoderLayer, |
| |
| |
| |
| } |
|
|
|
|
| def wenet_fsdp_wrap_policy(mode): |
| |
| |
| assert mode in ['no_shard', 'model', 'zero2', 'zero3'] |
| if mode == 'no_shard': |
| return None |
| else: |
| |
| |
| |
| if mode == 'model': |
| enc_dec_wrap_policy = partial( |
| lambda_auto_wrap_policy, |
| lambda_fn=lambda module: isinstance( |
| module, |
| tuple(WENET_ENCODER_CLASSES.values()) + tuple( |
| WENET_DECODER_CLASSES.values()))) |
| return enc_dec_wrap_policy |
| else: |
| to_wrap_class = set() |
| to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values())) |
| to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values())) |
| layers_wrap_policy = partial(transformer_auto_wrap_policy, |
| transformer_layer_cls=to_wrap_class) |
| return layers_wrap_policy |
|
|
|
|
| fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, |
| rank0_only=True) |
|
|
|
|
| def fsdp_save_model(model, save_model_path, info_dict): |
| |
| |
| |
|
|
| rank = int(os.environ.get('RANK', 0)) |
| with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, |
| fullstate_save_policy): |
| state_dict = model.state_dict() |
| if rank == 0: |
| save_state_dict_and_infos(state_dict, save_model_path, info_dict) |
|
|
|
|
| def check_gradient_checkpoint(model): |
| ckpt_laye_types = [] |
| if hasattr(model, 'encoder') and hasattr(model.encoder, |
| 'gradient_checkpointing'): |
| if model.encoder.gradient_checkpointing: |
| model.encoder.gradient_checkpointing = False |
| ckpt_laye_types += list(WENET_ENCODER_LAYERS_CLASSES.values()) |
| if hasattr(model, 'decoder') and hasattr(model.decoder, |
| 'gradient_checkpointing'): |
| if model.decoder.gradient_checkpointing: |
| model.decoder.gradient_checkpointing = False |
| ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values()) |
| if isinstance(model.decoder, DecoderOnly): |
| ckpt_laye_types += [DecoderOnly] |
| return tuple(ckpt_laye_types) |
|
|
|
|
| def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple): |
| |
| |
| |
| if len(ckpt_layer_types) == 0: |
| return |
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
| checkpoint_wrapper, |
| CheckpointImpl, |
| apply_activation_checkpointing, |
| ) |
| non_reentrant_wrapper = partial( |
| checkpoint_wrapper, |
| checkpoint_impl=CheckpointImpl.NO_REENTRANT, |
| ) |
| apply_activation_checkpointing( |
| model, |
| checkpoint_wrapper_fn=non_reentrant_wrapper, |
| check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types)) |
|
|