File size: 8,402 Bytes
857c2e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from transformers.trainer_pt_utils import get_module_class_from_name
import torch
import torch.nn as nn
from verl.utils.vla_utils.openvla_oft.modeling_prismatic import PrismaticProjector
def init_fn(x: torch.nn.Module):
    """Materialize a module's own tensors on the current CUDA device on non-zero ranks.

    Rank 0 gets ``x`` back untouched. Every other rank receives
    ``x.to_empty(device=torch.cuda.current_device(), recurse=False)``: only the
    parameters/buffers registered directly on ``x`` get (uninitialized) storage,
    not those of its children. The CUDA cache is then emptied.

    NOTE(review): presumably used as an FSDP ``param_init_fn`` for modules built
    on the meta device on non-zero ranks — confirm against callers.
    """
    if torch.distributed.get_rank() == 0:
        return x
    materialized = x.to_empty(device=torch.cuda.current_device(), recurse=False)
    torch.cuda.empty_cache()
    return materialized
def get_init_weight_context_manager(use_meta_tensor=True):
    """Return a zero-argument callable that produces the weight-init context.

    Args:
        use_meta_tensor: When true, non-zero ranks get accelerate's
            ``init_empty_weights`` and rank 0 gets a CPU-device context. When
            false, every rank gets the CPU-device context.

    Returns:
        A callable; presumably invoked as ``with init_context(): ...`` by
        callers (confirm). The CPU variant returns ``torch.device('cpu')``,
        which acts as a default-device context manager on torch >= 2.0.

    Raises:
        ImportError: only on the meta-tensor path, if ``accelerate`` is missing.
    """
    def cpu_init_weights():
        return torch.device('cpu')

    if not use_meta_tensor:
        return cpu_init_weights

    # Imported lazily so the CPU-only path does not require accelerate.
    from accelerate import init_empty_weights
    return init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
# Copyright 2020-present the HuggingFace Inc. team.
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
def get_fsdp_wrap_policy(module, config=None):
    """Build an FSDP ``auto_wrap_policy`` for ``module`` from ``config``.

    Args:
        module: The model to wrap. Its ``_no_split_modules`` attribute, if
            present, supplies the default layer class names to wrap.
        config: Optional mapping with the keys:
            ``disable`` (bool): return ``None`` when true.
            ``min_num_params`` (int): if > 0, use a size-based policy (this
                takes precedence over class-based wrapping).
            ``transformer_layer_cls_to_wrap`` (str or iterable of str): layer
                class names to wrap; a single string is treated as one name.

    Returns:
        A ``functools.partial`` usable as FSDP's ``auto_wrap_policy``, or
        ``None`` when wrapping is disabled or no wrapping rule is configured.

    Raises:
        ValueError: if a requested layer class name is not found in ``module``.
    """
    if config is None:
        config = {}
    if config.get('disable', False):
        return None

    default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = config.get("transformer_layer_cls_to_wrap",
                                                    default_transformer_cls_names_to_wrap)
    min_num_params = config.get('min_num_params', 0)

    if min_num_params > 0:
        return functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

    if fsdp_transformer_layer_cls_to_wrap is None:
        return None

    # A bare string would otherwise be iterated character by character.
    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

    transformer_cls_to_wrap = set()
    for layer_class in fsdp_transformer_layer_cls_to_wrap:
        transformer_cls = get_module_class_from_name(module, layer_class)
        if transformer_cls is None:
            raise ValueError(
                f"Could not find the transformer layer class {layer_class!r} to wrap in the model.")
        transformer_cls_to_wrap.add(transformer_cls)

    return functools.partial(
        transformer_auto_wrap_policy,
        # Transformer layer class to wrap
        transformer_layer_cls=transformer_cls_to_wrap,
    )
def get_fsdp_wrap_policy_vla(module, config=None, is_lora=False):
    """Build the FSDP auto-wrap policy for a Prismatic/OpenVLA-OFT model.

    The result is the union (``_or_policy``) of:
      * the timm vision backbone: ``VisionTransformer`` modules and ``Block`` layers;
      * the language model's layers named in ``module.language_model._no_split_modules``;
      * ``PrismaticProjector`` modules;
      * when ``is_lora`` is true, every leaf module whose ``weight`` requires grad.
    There is no fall-through policy: any module not matched above is folded
    into the root FSDP instance.

    Args:
        module: The VLA model; must expose a ``language_model`` attribute.
        config: Not read by this function; kept for signature parity with
            ``get_fsdp_wrap_policy``.
        is_lora: Also wrap leaf modules with a trainable ``weight``.

    Returns:
        A ``functools.partial`` of ``_or_policy`` usable as FSDP's ``auto_wrap_policy``.

    Raises:
        ValueError: if ``module.language_model`` has no ``_no_split_modules``,
            or a listed class name cannot be found in ``module``.
    """
    import logging

    from timm.models.vision_transformer import Block, VisionTransformer
    from torch.distributed.fsdp.wrap import (
        _module_wrap_policy,
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    logger = logging.getLogger(__name__)

    # Vision backbone: wrap the whole ViT as well as each of its transformer blocks.
    vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer})
    transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    vision_fsdp_wrapping_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    # Language model: wrap the layer classes it declares as non-splittable.
    llm_layer_cls_names = getattr(module.language_model, "_no_split_modules", None)
    if llm_layer_cls_names is None:
        raise ValueError(
            "module.language_model does not define `_no_split_modules`; cannot build the LLM wrap policy.")
    llm_layer_classes = set()
    for layer_class in llm_layer_cls_names:
        logger.debug("layer_class is : %s", layer_class)
        transformer_cls = get_module_class_from_name(module, layer_class)
        if transformer_cls is None:
            raise ValueError(
                f"Could not find the transformer layer class {layer_class!r} to wrap in the model.")
        llm_layer_classes.add(transformer_cls)
    llm_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        # Transformer layer class to wrap
        transformer_layer_cls=llm_layer_classes,
    )
    logger.debug("llm_wrap_policy: %s", llm_wrap_policy)

    # Prismatic projector: a plain module-class wrap around the projector.
    prismatic_fsdp_wrapping_policy = functools.partial(
        _module_wrap_policy,
        module_classes={PrismaticProjector},
    )

    vla_policies = [
        vision_fsdp_wrapping_policy,
        llm_wrap_policy,
        prismatic_fsdp_wrapping_policy,
    ]

    if is_lora:
        def lambda_policy_fn(submodule):
            # Wrap every leaf module that carries a trainable weight (intended for LoRA modules).
            return bool(
                len(list(submodule.named_children())) == 0
                and getattr(submodule, "weight", None) is not None
                and submodule.weight.requires_grad
            )

        vla_policies.append(functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn))

    return functools.partial(_or_policy, policies=vla_policies)
def offload_fsdp_grad(module):
    """Move every existing parameter gradient of ``module`` to CPU.

    Parameters whose ``.grad`` is ``None`` are skipped. The copies use
    ``non_blocking=True``. Afterwards the CUDA caching allocator is asked to
    release unused cached memory.
    """
    for _name, tensor in module.named_parameters():
        grad = tensor.grad
        if grad is None:
            continue
        tensor.grad = grad.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()
def load_fsdp_grad(module, device_id):
    """Move every existing parameter gradient of ``module`` to ``device_id``.

    Parameters whose ``.grad`` is ``None`` are skipped. The copies use
    ``non_blocking=True``. Afterwards the CUDA caching allocator is asked to
    release unused cached memory.
    """
    for _name, tensor in module.named_parameters():
        grad = tensor.grad
        if grad is None:
            continue
        tensor.grad = grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
def offload_fsdp_param_and_grad(module, offload_grad=False):
    """Move ``module``'s parameters (and optionally their gradients) to CPU.

    For each parameter:
      * ``_local_shard`` is moved first, if the parameter has that attribute;
      * then ``.data`` is moved;
      * then ``.grad`` is moved, when ``offload_grad`` is true and a gradient exists.
    All copies use ``non_blocking=True``. Afterwards the CUDA cache is emptied.
    """
    target = "cpu"
    for _name, tensor in module.named_parameters():
        if hasattr(tensor, "_local_shard"):
            tensor._local_shard = tensor._local_shard.to(target, non_blocking=True)
        tensor.data = tensor.data.to(target, non_blocking=True)
        move_grad = offload_grad and tensor.grad is not None
        if move_grad:
            tensor.grad = tensor.grad.to(target, non_blocking=True)
    torch.cuda.empty_cache()
def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    """Move ``module``'s parameters (and optionally their gradients) to ``device_id``.

    For each parameter:
      * ``_local_shard`` is moved first, if the parameter has that attribute;
      * then ``.data`` is moved;
      * then ``.grad`` is moved, when ``load_grad`` is true and a gradient exists.
    All copies use ``non_blocking=True``. Afterwards the CUDA cache is emptied.
    """
    for _name, tensor in module.named_parameters():
        if hasattr(tensor, "_local_shard"):
            tensor._local_shard = tensor._local_shard.to(device_id, non_blocking=True)
        tensor.data = tensor.data.to(device_id, non_blocking=True)
        move_grad = load_grad and tensor.grad is not None
        if move_grad:
            tensor.grad = tensor.grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
def offload_fsdp_optimizer(optimizer):
    """Move all tensor-valued optimizer state to CPU.

    Iterates the parameters listed in ``optimizer.param_groups``. Every tensor
    entry in a parameter's state dict is replaced by a ``non_blocking`` CPU
    copy; non-tensor entries are left as-is. Afterwards the CUDA cache is emptied.

    Parameters with no state are skipped. ``optimizer.state`` is a
    ``defaultdict`` in ``torch.optim.Optimizer``, so indexing it with ``[]``
    would insert an empty state dict for each such parameter; ``.get`` avoids
    that mutation.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            state = optimizer.state.get(param)
            if not state:
                continue
            # Reassigning existing keys does not resize the dict, so mutating while iterating is safe.
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()
def load_fsdp_optimizer(optimizer, device_id):
    """Move all tensor-valued optimizer state to ``device_id``.

    Iterates the parameters listed in ``optimizer.param_groups``. Every tensor
    entry in a parameter's state dict is replaced by a ``non_blocking`` copy on
    ``device_id``; non-tensor entries are left as-is. Afterwards the CUDA cache
    is emptied.

    Parameters with no state are skipped. ``optimizer.state`` is a
    ``defaultdict`` in ``torch.optim.Optimizer``, so indexing it with ``[]``
    would insert an empty state dict for each such parameter; ``.get`` avoids
    that mutation.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            state = optimizer.state.get(param)
            if not state:
                continue
            # Reassigning existing keys does not resize the dict, so mutating while iterating is safe.
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
|