| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import functools |
| from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy |
| from transformers.trainer_pt_utils import get_module_class_from_name |
| import torch |
| import torch.nn as nn |
| from verl.utils.vla_utils.openvla_oft.modeling_prismatic import PrismaticProjector |
|
|
def init_fn(x: torch.nn.Module):
    """Give non-zero ranks uninitialized CUDA storage for ``x``'s own tensors.

    Rank 0 returns ``x`` untouched. Every other rank calls
    ``to_empty(..., recurse=False)``, so only ``x``'s direct parameters and
    buffers (not its children's) are re-allocated, uninitialized, on the
    current CUDA device. The CUDA cache is then emptied.
    NOTE(review): presumably used as an FSDP ``param_init_fn`` where rank 0
    holds the real weights; confirm against callers.

    Returns:
        The module, either the original object or the one returned by ``to_empty``.
    """
    rank = torch.distributed.get_rank()
    if rank != 0:
        target_device = torch.cuda.current_device()
        x = x.to_empty(device=target_device, recurse=False)
        torch.cuda.empty_cache()
    return x
|
|
|
|
def get_init_weight_context_manager(use_meta_tensor=True):
    """Return a zero-argument factory whose result is entered while building model weights.

    * ``use_meta_tensor=False``: every rank gets the CPU factory.
    * ``use_meta_tensor=True``: rank 0 gets the CPU factory, and all other ranks
      get accelerate's ``init_empty_weights`` (meta-device, i.e. storage-free,
      parameter creation).

    The distributed rank is only queried when ``use_meta_tensor`` is true.
    """
    from accelerate import init_empty_weights

    def cpu_init_weights():
        # Entered as ``with cpu_init_weights():``. torch.device acts as a
        # default-device context manager (torch >= 2.0).
        return torch.device('cpu')

    if not use_meta_tensor:
        return cpu_init_weights
    if torch.distributed.get_rank() == 0:
        return cpu_init_weights
    return init_empty_weights
|
|
|
|
| |
| |
def get_fsdp_wrap_policy(module, config=None):
    """Build an FSDP ``auto_wrap_policy`` for ``module`` from ``config``.

    Args:
        module: The model to be wrapped. Its ``_no_split_modules`` attribute, if
            present, provides the default transformer-layer class names.
        config: Optional mapping with these keys:
            ``disable`` (bool): return ``None``, meaning no auto wrapping.
            ``min_num_params`` (int): if > 0, use size-based wrapping with this
                threshold. This takes precedence over class-based wrapping.
            ``transformer_layer_cls_to_wrap`` (str or iterable of str): layer
                class names to wrap. Overrides ``module._no_split_modules``.

    Returns:
        A ``functools.partial`` over a ``torch.distributed.fsdp.wrap`` policy,
        or ``None`` when wrapping is disabled or no rule applies.

    Raises:
        ValueError: A requested layer class name is not found in ``module``.
    """
    if config is None:
        config = {}

    if config.get('disable', False):
        return None

    default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = config.get("transformer_layer_cls_to_wrap",
                                                    default_transformer_cls_names_to_wrap)
    # A single class name given as a plain string would otherwise be iterated
    # character by character and never resolve.
    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

    min_num_params = config.get('min_num_params', 0)
    if min_num_params > 0:
        return functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

    if fsdp_transformer_layer_cls_to_wrap is None:
        return None

    transformer_cls_to_wrap = set()
    for layer_class in fsdp_transformer_layer_cls_to_wrap:
        transformer_cls = get_module_class_from_name(module, layer_class)
        if transformer_cls is None:
            raise ValueError(
                f"Could not find the transformer layer class to wrap in the model: {layer_class!r}")
        transformer_cls_to_wrap.add(transformer_cls)

    return functools.partial(
        transformer_auto_wrap_policy,
        # Each instance of these classes becomes its own FSDP unit.
        transformer_layer_cls=transformer_cls_to_wrap,
    )
|
|
def get_fsdp_wrap_policy_vla(module, config=None, is_lora=False):
    """Build the composite FSDP ``auto_wrap_policy`` for a Prismatic/OpenVLA-style model.

    The returned policy is the OR of these sub-policies:
      * Vision backbone: timm ``VisionTransformer`` modules and their ``Block`` layers.
      * Language model: the layer classes named in
        ``module.language_model._no_split_modules``.
      * ``PrismaticProjector`` modules.
      * Only when ``is_lora``: every leaf module whose ``weight`` requires grad.

    Args:
        module: The VLA model. It must expose a ``language_model`` attribute.
        config: Accepted for signature parity with ``get_fsdp_wrap_policy``;
            currently unused.
        is_lora: Add the trainable-leaf sub-policy.

    Returns:
        ``functools.partial(_or_policy, policies=[...])``.

    Raises:
        ValueError: ``module.language_model`` has no ``_no_split_modules``, or a
            named layer class cannot be found in ``module``.
    """
    from timm.models.vision_transformer import Block, VisionTransformer
    from torch.distributed.fsdp.wrap import (
        _module_wrap_policy,
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    # Vision backbone: wrap whole ViTs and each transformer block inside them.
    vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer})
    transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    vision_fsdp_wrapping_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    # Language model: the wrap classes come from its declared non-splittable modules.
    llm_layer_cls_names = getattr(module.language_model, "_no_split_modules", None)
    if llm_layer_cls_names is None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError("module.language_model defines no `_no_split_modules`; cannot build the LLM wrap policy.")

    transformer_cls_to_wrap = set()
    for layer_class in llm_layer_cls_names:
        # Names are resolved against the whole model, which contains language_model.
        transformer_cls = get_module_class_from_name(module, layer_class)
        if transformer_cls is None:
            raise ValueError(
                f"Could not find the transformer layer class to wrap in the model: {layer_class!r}")
        transformer_cls_to_wrap.add(transformer_cls)
    llm_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_cls_to_wrap,
    )

    prismatic_fsdp_wrapping_policy = functools.partial(
        _module_wrap_policy,
        module_classes={PrismaticProjector},
    )

    vla_policies = [
        vision_fsdp_wrapping_policy,
        llm_wrap_policy,
        prismatic_fsdp_wrapping_policy,
    ]

    if is_lora:
        def lambda_policy_fn(submodule):
            # Leaf modules (no children) that own a trainable weight get their
            # own FSDP unit. NOTE(review): presumably these are the LoRA adapter layers.
            return bool(
                len(list(submodule.named_children())) == 0
                and getattr(submodule, "weight", None) is not None
                and submodule.weight.requires_grad
            )

        vla_policies.append(functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn))

    return functools.partial(_or_policy, policies=vla_policies)
|
|
|
|
def offload_fsdp_grad(module):
    """Move every existing parameter gradient of ``module`` to CPU, then empty the CUDA cache.

    Parameters whose ``.grad`` is ``None`` are skipped. Copies are issued with
    ``non_blocking=True``.
    """
    params_with_grad = (p for _, p in module.named_parameters() if p.grad is not None)
    for param in params_with_grad:
        param.grad = param.grad.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()
|
|
|
|
def load_fsdp_grad(module, device_id):
    """Move every existing parameter gradient of ``module`` onto ``device_id``, then empty the CUDA cache.

    Parameters without a gradient are left alone. Copies are issued with
    ``non_blocking=True``.
    """
    for _, param in module.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        param.grad = grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
|
|
|
|
def offload_fsdp_param_and_grad(module, offload_grad=False):
    """Move ``module``'s parameters to CPU, and their gradients too if ``offload_grad``.

    The CUDA cache is emptied at the end.

    Some parameters carry a ``_local_shard`` attribute (presumably FSDP flat
    parameters). On these, ``param.data`` and ``param._local_shard`` typically
    share one storage. Moving them independently would create two separate
    copies, doubling the memory and breaking the alias: later updates to
    ``.data`` would not be seen through ``_local_shard``. When the two alias,
    ``.data`` is moved once and ``_local_shard`` is re-pointed at it. Otherwise
    each is moved on its own.

    NOTE(review): copies use ``non_blocking=True``; host-side reads of the CPU
    tensors may need ``torch.cuda.synchronize()`` first.
    """
    for _, param in module.named_parameters():
        local_shard = getattr(param, "_local_shard", None)
        aliased = (
            local_shard is not None
            and local_shard.device == param.data.device
            and local_shard.size() == param.data.size()
            and local_shard.data_ptr() == param.data.data_ptr()
        )
        if aliased:
            param.data = param.data.to('cpu', non_blocking=True)
            # A fresh view of the same storage, restoring the shard/data alias.
            param._local_shard = param.data
        else:
            if local_shard is not None:
                param._local_shard = local_shard.to("cpu", non_blocking=True)
            param.data = param.data.to('cpu', non_blocking=True)
        if offload_grad and param.grad is not None:
            param.grad = param.grad.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()
|
|
|
|
def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    """Move ``module``'s parameters to ``device_id``, and their gradients too if ``load_grad``.

    The CUDA cache is emptied at the end.

    Some parameters carry a ``_local_shard`` attribute (presumably FSDP flat
    parameters). On these, ``param.data`` and ``param._local_shard`` typically
    share one storage. Moving them independently would allocate two separate
    device copies, doubling device memory and breaking the alias: optimizer
    updates to ``.data`` would not be seen through ``_local_shard``. When the
    two alias, ``.data`` is moved once and ``_local_shard`` is re-pointed at it.
    Otherwise each is moved on its own. Copies use ``non_blocking=True``.
    """
    for _, param in module.named_parameters():
        local_shard = getattr(param, "_local_shard", None)
        aliased = (
            local_shard is not None
            and local_shard.device == param.data.device
            and local_shard.size() == param.data.size()
            and local_shard.data_ptr() == param.data.data_ptr()
        )
        if aliased:
            param.data = param.data.to(device_id, non_blocking=True)
            # A fresh view of the same storage, restoring the shard/data alias.
            param._local_shard = param.data
        else:
            if local_shard is not None:
                param._local_shard = local_shard.to(device_id, non_blocking=True)
            param.data = param.data.to(device_id, non_blocking=True)
        if load_grad and param.grad is not None:
            param.grad = param.grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
|
|
|
|
def offload_fsdp_optimizer(optimizer):
    """Move every tensor in ``optimizer``'s per-parameter state to CPU, then empty the CUDA cache.

    Only parameters listed in ``optimizer.param_groups`` are visited.
    Non-tensor state values are left untouched. Parameters with no state yet
    (for example, before the first ``step()``) are skipped without adding
    entries to ``optimizer.state``. Copies use ``non_blocking=True``.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            # torch.optim.Optimizer.state is a defaultdict(dict); indexing it
            # with a missing key would silently insert an empty entry.
            state = optimizer.state.get(param)
            if not state:
                continue
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()
|
|
|
|
def load_fsdp_optimizer(optimizer, device_id):
    """Move every tensor in ``optimizer``'s per-parameter state onto ``device_id``, then empty the CUDA cache.

    Only parameters listed in ``optimizer.param_groups`` are visited.
    Non-tensor state values are left untouched. Parameters with no state yet
    are skipped without adding entries to ``optimizer.state``. Copies use
    ``non_blocking=True``.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            # torch.optim.Optimizer.state is a defaultdict(dict); indexing it
            # with a missing key would silently insert an empty entry.
            state = optimizer.state.get(param)
            if not state:
                continue
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
|
|