|
|
|
|
|
|
|
|
|
|
| import logging
|
| from contextlib import suppress
|
| from functools import partial
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
|
| from torch.distributed.device_mesh import DeviceMesh
|
| from torch.distributed.fsdp import register_fsdp_forward_method
|
| from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
|
| from torch.utils.checkpoint import create_selective_checkpoint_contexts
|
|
|
| logger = logging.getLogger("dinov3")
|
|
|
|
|
| def map_modules_and_blocks(models: list[nn.Module], callable) -> None:
|
| for m in models:
|
| for block_id, block in enumerate(m.blocks):
|
| m.blocks[block_id] = callable(block, is_backbone_block=True)
|
|
|
|
|
| def ac_compile_parallelize_and_init(
|
| clip_model: nn.Module,
|
| world_mesh: DeviceMesh,
|
| do_compile: bool,
|
| use_activation_checkpointing: bool,
|
| use_full_activation_checkpointing: bool,
|
| use_cuda_graphs: bool,
|
| param_dtype_str: str = "bf16",
|
| reduce_dtype_str: str = "fp32",
|
| ) -> None:
|
| """
|
| Order of the wrappers:
|
| 1/ Activation checkpointing on blocks
|
| 2/ Compile blocks
|
| 3/ FSDP blocks + global model
|
| """
|
| logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
|
|
|
|
|
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
| checkpoint_wrapper,
|
| )
|
|
|
| trained_models = []
|
| inference_only_models = []
|
| for model in [clip_model.visual_model, clip_model.text_model]:
|
| if not model.freeze_backbone:
|
| trained_models.append(model.backbone)
|
| else:
|
| inference_only_models.append(model.backbone)
|
| trained_models.append(model.head)
|
|
|
| for model in trained_models:
|
| if use_activation_checkpointing:
|
| if use_full_activation_checkpointing:
|
| _checkpointing_wrapper = checkpoint_wrapper
|
| logger.info(
|
| "using selective checkpointing on backbone with full checkpointing policy"
|
| )
|
| else:
|
| _save_list = [
|
|
|
| torch.ops.aten.mm.default,
|
| torch.ops.aten._scaled_mm.default,
|
|
|
| torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
| torch.ops.aten._scaled_dot_product_flash_attention.default,
|
| torch.ops._c10d_functional.reduce_scatter_tensor.default,
|
| ]
|
| with suppress(
|
| AttributeError
|
| ):
|
| _save_list.append(torch.ops.xformers_flash3.flash_fwd.default)
|
| _checkpointing_wrapper = partial(
|
| checkpoint_wrapper,
|
| context_fn=partial(
|
| create_selective_checkpoint_contexts, _save_list
|
| ),
|
| preserve_rng_state=True,
|
| )
|
| logger.info(
|
| "using selective checkpointing on backbone with selective policy"
|
| )
|
| for i, b in enumerate(model.blocks):
|
| if not isinstance(b, nn.Identity):
|
| model.blocks[i] = _checkpointing_wrapper(b)
|
|
|
|
|
| def compile_block(block: nn.Module) -> nn.Module:
|
| if do_compile:
|
| if use_cuda_graphs:
|
| block.compile(
|
| fullgraph=True, dynamic=False, options={"triton.cudagraphs": True}
|
| )
|
| else:
|
| block.compile()
|
| return block
|
|
|
| def compile_backbone(backbone: nn.Module) -> nn.Module:
|
| for block_id, block in enumerate(backbone.blocks):
|
| backbone.blocks[block_id] = compile_block(block)
|
|
|
| def compile_head(head: nn.Module) -> nn.Module:
|
| for block_id in range(head.num_blocks):
|
| head.blocks[block_id] = compile_block(head.blocks[block_id])
|
| if do_compile and isinstance(head.linear_projection, nn.Linear):
|
| head.linear_projection.compile()
|
|
|
| compile_backbone(clip_model.visual_model.backbone)
|
| compile_backbone(clip_model.text_model.backbone)
|
| compile_head(clip_model.visual_model.head)
|
| compile_head(clip_model.text_model.head)
|
| DTYPE_MAP = {
|
| "fp16": torch.float16,
|
| "bf16": torch.bfloat16,
|
| "fp32": torch.float32,
|
| }
|
| mp_policy = MixedPrecisionPolicy(
|
| param_dtype=DTYPE_MAP[param_dtype_str],
|
| reduce_dtype=DTYPE_MAP[reduce_dtype_str],
|
| )
|
| fsdp_config = {"mesh": world_mesh["dp"], "mp_policy": mp_policy}
|
|
|
| for block in clip_model.visual_model.backbone.blocks:
|
| fully_shard(block, **fsdp_config, reshard_after_forward=True)
|
| for i in range(clip_model.visual_model.head.num_blocks):
|
| fully_shard(
|
| clip_model.visual_model.head.blocks[i],
|
| **fsdp_config,
|
| reshard_after_forward=True,
|
| )
|
| fully_shard(
|
| clip_model.visual_model.head.linear_projection,
|
| **fsdp_config,
|
| reshard_after_forward=True,
|
| )
|
| fully_shard(
|
| clip_model.visual_model.backbone, **fsdp_config, reshard_after_forward=True
|
| )
|
| fully_shard(clip_model.visual_model.head, **fsdp_config, reshard_after_forward=True)
|
| register_fsdp_forward_method(
|
| clip_model.visual_model.backbone, "get_intermediate_layers"
|
| )
|
| for block in clip_model.text_model.backbone.blocks:
|
| fully_shard(block, **fsdp_config, reshard_after_forward=True)
|
| for i in range(clip_model.text_model.head.num_blocks):
|
| fully_shard(
|
| clip_model.text_model.head.blocks[i],
|
| **fsdp_config,
|
| reshard_after_forward=True,
|
| )
|
| fully_shard(
|
| clip_model.text_model.head.linear_projection,
|
| **fsdp_config,
|
| reshard_after_forward=True,
|
| )
|
| fully_shard(
|
| clip_model.text_model.backbone, **fsdp_config, reshard_after_forward=True
|
| )
|
| fully_shard(clip_model.text_model.head, **fsdp_config, reshard_after_forward=True)
|
|
|
| clip_model.to_empty(device="cuda")
|
| clip_model.init_weights()
|
|
|
| for model in inference_only_models:
|
| fsdp_state: FSDPState = model._get_fsdp_state()
|
| if not fsdp_state._fsdp_param_group:
|
| continue
|
| mi = fsdp_state._fsdp_param_group.post_forward_mesh_info
|
| fsdp_state._lazy_init()
|
| fsdp_state._fsdp_param_group.post_forward_mesh_info = mi
|
|
|