# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Optional, Sequence, Tuple, Union

import torch
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME as DIFFUSERS_SAFE_WEIGHTS_INDEX_NAME
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME as DIFFUSERS_SAFETENSORS_WEIGHTS_NAME
from torch import distributed as dist
from torch import nn
from tqdm import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
from transformers.utils.import_utils import is_safetensors_available

from ..utils import logging
from ..utils.helper import empty_cache, get_dtype_size


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file


if TYPE_CHECKING:
    from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

    ModelAssets = Union[GenerationConfig, PretrainedConfig, PreTrainedTokenizer, ProcessorMixin]


logger = logging.get_logger(__name__)

# Substrings of parameter names that `load_model_weights` skips when re-initializing
# missing parameters with `adanorm_time=True`.
_ADANORM_EXCLUDE_KEYWORDS = (
    "input_layernorm.gamma_beta_gate",
    "post_attention_layernorm.gamma_beta_gate",
    "norm.gamma_beta_gate",
    "input_layernorm.gamma",
    "post_attention_layernorm.gamma",
    "norm.gamma",
    "input_layernorm.beta",
    "post_attention_layernorm.beta",
    "norm.beta",
    "input_layernorm.gate",
    "post_attention_layernorm.gate",
    "norm.gate",
)


@contextmanager
def init_empty_weights():
    """
    A context manager under which models are initialized with all parameters on the meta device.

    While the context is active, `nn.Module.register_parameter` is replaced by a wrapper that
    registers the parameter normally and then swaps it for a copy on the "meta" device. The
    original method is restored on exit, including when the body raises.

    Borrowed from: https://github.com/huggingface/accelerate/blob/v1.0.0rc1/src/accelerate/big_modeling.py#L57
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module: "nn.Module", name: str, param: "nn.Parameter"):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            # Re-use the registered parameter's attribute dict so the meta copy is built
            # with the same extra attributes (kept identical to the accelerate original).
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to("meta"), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


@dataclass
class StateDictIterator:
    """
    Iterable over the `(name, tensor)` pairs stored in one checkpoint file.

    Files ending in ".safetensors" are read through `safe_open` on CPU; any other file is
    loaded with `torch.load(..., map_location="cpu", weights_only=True, mmap=True)`.
    """

    filepath: str
    # Not read by `__iter__`; kept so existing constructor calls remain valid.
    prefix: str = ''

    def __iter__(self) -> Generator[Tuple[str, "torch.Tensor"], None, None]:
        if self.filepath.endswith(".safetensors"):
            with safe_open(self.filepath, framework="pt", device="cpu") as f:
                for key in f.keys():
                    yield key, f.get_tensor(key)
        else:
            state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True)
            yield from state_dict.items()


def _load_state_dict(
    weights_path: str, expert_vision_path: Optional[str] = None, **kwargs
) -> List["StateDictIterator"]:
    """
    Loads (sharded) state dict in transformers' format.

    Candidate file names are probed in a fixed order (transformers safetensors, diffusers
    safetensors, then transformers `.bin`; single file before index for each). The first
    one that `cached_file` resolves wins.

    Args:
        weights_path: Location passed to `cached_file` for every candidate.
        expert_vision_path: When given and the transformers safetensors index is the file
            that resolves, shard files are resolved relative to this path instead of
            `weights_path`.
        **kwargs: Forwarded to `cached_file` and `get_checkpoint_shard_files`.

    Returns:
        One `StateDictIterator` per checkpoint file.

    Raises:
        ValueError: If none of the candidate files can be resolved.
    """
    cache_kwargs = {"_raise_exceptions_for_missing_entries": False, **kwargs}

    # (file name, is a shard index, shards may be resolved from `expert_vision_path`)
    candidates = (
        (SAFE_WEIGHTS_NAME, False, False),
        (SAFE_WEIGHTS_INDEX_NAME, True, True),
        (DIFFUSERS_SAFETENSORS_WEIGHTS_NAME, False, False),
        (DIFFUSERS_SAFE_WEIGHTS_INDEX_NAME, True, False),
        (WEIGHTS_NAME, False, False),
        (WEIGHTS_INDEX_NAME, True, False),
    )
    for file_name, is_index, allow_expert_override in candidates:
        resolved_weight_file = cached_file(weights_path, file_name, **cache_kwargs)
        if not resolved_weight_file:
            continue

        if not is_index:
            return [StateDictIterator(resolved_weight_file)]

        use_expert_path = allow_expert_override and expert_vision_path is not None
        shard_root = expert_vision_path if use_expert_path else weights_path
        shard_files, _ = get_checkpoint_shard_files(shard_root, resolved_weight_file, **kwargs)
        return [StateDictIterator(shard_file) for shard_file in shard_files]

    raise ValueError(f"Cannot find checkpoint files in {weights_path}.")


def _find_submodule(module: "nn.Module", name: str) -> Tuple["nn.Module", str]:
    """
    Finds the leaf module according to the name.

    Args:
        module: Root module to start from.
        name: Dotted path such as "a.b.weight".

    Returns:
        The module reached by following every component but the last, and the last component.

    Raises:
        ValueError: If an intermediate component is not an attribute of the current module.
    """
    pieces = name.split(".")
    for piece in pieces[:-1]:
        if not hasattr(module, piece):
            raise ValueError(f"Cannot find {piece} in {module}.")

        module = getattr(module, piece)

    return module, pieces[-1]


def _dispatch_parameter(
    module: "nn.Module",
    name: str,
    tensor: "torch.Tensor",
    dtensor_factory: Optional[Callable[["torch.Tensor", Any, Any], "torch.Tensor"]] = None,
) -> None:
    """
    Assigns parameter to an empty model.

    NOTE: FSDP module must use in-place operators.

    Args:
        module: Root module that owns the parameter.
        name: Dotted name of the parameter.
        tensor: Loaded value; it is converted with `tensor.to(existing_parameter_data)` first.
        dtensor_factory: Called as `dtensor_factory(tensor, device_mesh, placements)` when the
            existing parameter data exposes a `device_mesh` attribute.

    Raises:
        ValueError: If the target is a dtensor on CPU, or a dtensor and no `dtensor_factory`
            was supplied.
    """
    module, name = _find_submodule(module, name)
    orig_tensor = module._parameters[name].data
    tensor = tensor.to(orig_tensor)
    if hasattr(orig_tensor, "device_mesh"):  # dtensor
        if orig_tensor.device.type == "cpu":
            raise ValueError("Cannot load dtensor on CPU.")

        if dtensor_factory is None:
            # Fail with a clear message instead of "'NoneType' object is not callable".
            raise ValueError(f"`dtensor_factory` is required to load dtensor parameter {name}.")

        device_mesh = getattr(orig_tensor, "device_mesh")
        placements = getattr(orig_tensor, "placements")
        module._parameters[name].data.copy_(dtensor_factory(tensor, device_mesh, placements))
    else:  # not dtensor
        module._parameters[name].data.copy_(tensor)


def _dispatch_buffer(
    module: "nn.Module",
    name: str,
    buffer: "torch.Tensor",
) -> None:
    """
    Assigns buffer to an empty model.

    The buffer entry is replaced (not copied in place) by `buffer` converted with
    `buffer.to(existing_buffer_data)`.
    """
    module, name = _find_submodule(module, name)
    orig_tensor = module._buffers[name].data
    module._buffers[name] = buffer.to(orig_tensor)


def _init_parameter(
    module: "nn.Module",
    name: str,
) -> None:
    """
    Initializes parameter in model.

    Walks the dotted path of `name` and remembers the `_init_weights` attribute of the last
    ancestor (the owning module itself is not inspected) that defines one, then applies it
    to the owning module via `nn.Module.apply`.

    Raises:
        ValueError: If a path component is missing, or no ancestor defines `_init_weights`.
    """
    pieces = name.split(".")
    init_func = None
    for piece in pieces[:-1]:
        if not hasattr(module, piece):
            raise ValueError(f"Cannot find {piece} in {module}.")

        if hasattr(module, "_init_weights"):
            init_func = getattr(module, "_init_weights")

        module = getattr(module, piece)

    if init_func is None:
        raise ValueError(f"Cannot retrieve `_init_weights` function in the parents of {module}.")

    module.apply(init_func)


def get_model_prefix(parameter_names):
    """
    Derives the name prefix to prepend to checkpoint keys when only the VLM is loaded.

    Args:
        parameter_names: Iterable of dotted parameter names.

    Returns:
        The first three dotted components plus a trailing "." of the first name encountered
        whose second component is "qwenvl_with_expert" and whose third component does not
        contain "expert"; an empty string when no name matches. Names with fewer than three
        components are ignored.
    """
    vlm_prefix = ''
    for param_name in parameter_names:
        parts = param_name.split('.')
        if len(parts) < 3:
            # Too short to carry `<root>.qwenvl_with_expert.<sub>`; previously an IndexError.
            continue
        if parts[1] == 'qwenvl_with_expert' and 'expert' not in parts[2]:
            vlm_prefix = '.'.join(parts[:3]) + '.'
            break
    return vlm_prefix


@torch.no_grad()
def load_model_weights(
    model: Union["nn.Module", "PreTrainedModel"],
    weights_path: str,
    init_device: Literal["cpu", "cuda"] = "cuda",
    dtensor_factory: Optional[Callable[["torch.Tensor", Any, Any], "torch.Tensor"]] = None,
    load_vlm_only: bool = False,
    enable_expert_vision: bool = False,
    expert_vision_path: Optional[str] = None,
    post_training: bool = False,
    incremental_training: bool = False,
    depth_incremental_training: bool = False,
    norm_qkv: bool = False,
    adanorm_time: bool = False,
) -> None:
    """
    Loads pre-trained model states in transformers' format.

    The model is materialized with `model.to_empty(device=init_device)`, checkpoint tensors
    are copied into matching parameters/buffers, and parameters that were not found in the
    checkpoint are re-initialized through `_init_parameter` after mode-specific sanity checks.

    Args:
        model: Model to fill. It must expose `model.model.qwenvl_with_expert.qwenvl` (and
            `.expert_visual` when expert vision is enabled) and a `config` attribute.
        weights_path: Checkpoint location, resolved by `_load_state_dict`.
        init_device: Device passed to `to_empty`.
        dtensor_factory: Forwarded to `_dispatch_parameter` for dtensor parameters.
        load_vlm_only: Prefix checkpoint keys with the result of `get_model_prefix` and
            require that no VLM parameter is left missing.
        enable_expert_vision: Enables the expert-visual checks (same effect as passing
            `expert_vision_path`).
        expert_vision_path: Forwarded to `_load_state_dict`; also enables the expert-visual
            checks.
        post_training: Require an exact key match; unexpected keys raise `KeyError` and
            missing keys fail an assertion.
        incremental_training: Parameters whose dispatch raises are kept as missing and
            re-initialized instead of aborting.
        depth_incremental_training: Only "depth_align_head." parameters may be missing.
        norm_qkv: Selects which name patterns may be missing in incremental training.
        adanorm_time: Skip re-initialization of the AdaNorm parameters listed in
            `_ADANORM_EXCLUDE_KEYWORDS`.

    Raises:
        KeyError: On an unexpected checkpoint key when `post_training` is set.
        AssertionError: When a mode-specific check on missing/unexpected keys fails.
        ValueError: Propagated from the loading helpers.
    """
    buffer_dict = {name: buffer.clone() for name, buffer in model.named_buffers()}
    parameter_names = {name for name, _ in model.named_parameters()}
    use_expert_vision = expert_vision_path is not None or enable_expert_vision

    vlm_parameter_names = {name for name, _ in model.model.qwenvl_with_expert.qwenvl.named_parameters()}
    print(f'====vlm contains {len(vlm_parameter_names)} paras=====')
    if use_expert_vision:
        dino_parameter_names = {name for name, _ in model.model.qwenvl_with_expert.expert_visual.named_parameters()}
        print(f'====dino contains {len(dino_parameter_names)} paras=====')

    model.to_empty(device=init_device)

    if post_training:
        logger.info_rank0(">>> Doing Post-Training now, no need to load LLM's embedding weight.")
    elif incremental_training:
        logger.info_rank0(">>> Load pretrained weights for incremental training.")
    elif load_vlm_only:
        logger.info_rank0(">>> Doing Pre-Training now.")
    else:
        logger.info_rank0(">>> Fine-tuneing based on PI0 now.")  # TODO

    state_dict_iterators = _load_state_dict(weights_path, expert_vision_path)
    vlm_prefix = get_model_prefix(parameter_names) if load_vlm_only else ''
    for state_dict_iterator in tqdm(
        state_dict_iterators, desc="Loading checkpoint shards", disable=int(os.getenv("LOCAL_RANK", "-1")) > 0
    ):
        for name, tensor in state_dict_iterator:
            # Map the checkpoint key onto the model's naming scheme.
            if 'expert_visual.' in name and not post_training:
                name = 'model.qwenvl_with_expert.' + name
            else:
                name = vlm_prefix + name

            if name in buffer_dict:  # persistent buffers
                buffer_dict[name] = tensor.clone()
            elif name in parameter_names:
                if incremental_training:
                    # Best effort: a parameter that cannot be dispatched stays in
                    # `parameter_names` and is re-initialized below.
                    try:
                        _dispatch_parameter(model, name, tensor, dtensor_factory)
                        parameter_names.remove(name)
                    except Exception as e:
                        logger.info_rank0(
                            f">>>The {name} weight need to be reinitialized. ({type(e).__name__}: {e})"
                        )
                else:
                    parameter_names.remove(name)
                    _dispatch_parameter(model, name, tensor, dtensor_factory)
            else:
                if post_training:
                    error_msg = f"Unexpected key '{name}' found in state dict during Post-Training. This is not allowed!!!"
                    logger.info_rank0(error_msg)
                    raise KeyError(error_msg)
                if use_expert_vision:
                    assert '.expert_visual.' not in name, "vision encoder need to be inited for action expert!"
                logger.info_rank0(f"Unexpected key in state dict: {name}.")

        del state_dict_iterator
        empty_cache()

    for name, buffer in buffer_dict.items():
        _dispatch_buffer(model, name, buffer)

    if post_training:
        assert len(parameter_names) == 0, f"Missing {parameter_names} during Post-Training. This is not allowed!!!"

    if len(parameter_names) > 0:
        if load_vlm_only and use_expert_vision and not incremental_training:
            num_missing_vlm_para, num_missing_dino_para = 0, 0
            for name in parameter_names:
                if '.paligemma.' in name or '.qwenvl.' in name:
                    num_missing_vlm_para += 1
                elif '.expert_visual.' in name:
                    num_missing_dino_para += 1
            print(f'====Missing {num_missing_vlm_para} paras in vlm====')
            print(f'====Missing {num_missing_dino_para} paras in DINO====')
            assert (
                all('.paligemma.' not in name for name in parameter_names)
                or all('.qwenvl.' not in name for name in parameter_names)
            ) and all(
                '.expert_visual.' not in name for name in parameter_names
            ), "Parameters in VLM and Expert_Visual are not loaded when PreTraining!!!"
        elif incremental_training and not depth_incremental_training:
            if norm_qkv:
                assert all(
                    '_proj.' in name or '_layernorm.' in name for name in parameter_names
                ), "Only MLP weight can be reinitialized when IncrementalTraining!!!"
            else:
                assert all(
                    '_proj.' in name or 'gate' in name for name in parameter_names
                ), "Only MLP weight can be reinitialized when IncrementalTraining!!!"
        elif depth_incremental_training:
            assert all(
                'depth_align_head.' in name for name in parameter_names
            ), "Only depth align head weight can be reinitialized when IncrementalTraining with Depth Model!!!"
        elif load_vlm_only:
            num_missing_vlm_para = 0
            for name in parameter_names:
                if '.paligemma.' in name or '.qwenvl.' in name:
                    num_missing_vlm_para += 1
            print(f'====Missing {num_missing_vlm_para} paras in vlm====')
            assert all('.paligemma.' not in name for name in parameter_names) or all(
                '.qwenvl.' not in name for name in parameter_names
            ), "Parameters in VLM are not loaded when PreTraining!!!"

        logger.info_rank0(f"Find missing key(s) in state dict: {parameter_names}, initialize them.")
        if adanorm_time:
            # NOTE(review): the excluded parameters are only skipped here; nothing in this
            # function zeroes them after `to_empty()`. Presumably they are zero-initialized
            # elsewhere - confirm.
            logger.info_rank0(">>> Parameters in AdaNorm has been ZERO initialized.")

        for name in parameter_names:
            if adanorm_time and any(keyword in name for keyword in _ADANORM_EXCLUDE_KEYWORDS):
                continue
            _init_parameter(model, name)

    # we should tie embeddings after loading weights because to_empty() leads to untied weights,
    # except for fsdp1 (custom init) and fsdp2 (swap tensor) contexts.
    if getattr(model.config, "tie_word_embeddings", True):
        try:
            input_embeddings = model.get_input_embeddings()
            output_embeddings = model.get_output_embeddings()
            output_embeddings._parameters["weight"] = input_embeddings._parameters["weight"]
        except Exception as e:
            logger.info_rank0(f"Failed to tie embeddings: {e}")


def _resolve_save_dtype(save_dtype: Optional[Union[str, "torch.dtype"]]) -> Optional["torch.dtype"]:
    """
    Normalizes a `save_dtype` argument to a `torch.dtype`.

    Returns:
        `getattr(torch, save_dtype)` for a string, the value itself for a `torch.dtype`,
        and None otherwise (meaning "keep each tensor's own dtype").
    """
    if isinstance(save_dtype, str):
        return getattr(torch, save_dtype)
    if isinstance(save_dtype, torch.dtype):
        return save_dtype
    return None


def _get_shard_info(
    state_dict: Dict[str, "torch.Tensor"],
    save_dtype: Optional[Union[str, "torch.dtype"]],
    shard_size: int,
    safe_serialization: bool,
) -> Tuple[bool, int, Dict[str, str]]:
    """
    Gets the shard information, should be executed at rank 0.

    Tensors are assigned to shards greedily in iteration order: a new shard is started when
    adding the next tensor would push a non-empty shard past `shard_size` bytes.

    Args:
        state_dict: Mapping from tensor name to tensor (or dtensor).
        save_dtype: Dtype used for size accounting; falls back to each tensor's own dtype.
        shard_size: Maximum shard size in bytes.
        safe_serialization: Selects the safetensors or the `.bin` file name pattern.

    Returns:
        `(is_sharded, total_size, weight_map)` where `weight_map` maps every tensor name to
        the file that will hold it. `is_sharded` is True only when more than one shard exists.
    """
    target_dtype = _resolve_save_dtype(save_dtype)

    current_size, total_size = 0, 0
    current_shard, shard_list = [], []
    for name, tensor in state_dict.items():
        dtype = target_dtype if target_dtype is not None else tensor.dtype
        tensor_size = tensor.numel() * get_dtype_size(dtype)  # dtensor's numel == tensor's numel
        if current_size != 0 and current_size + tensor_size > shard_size:
            total_size += current_size
            shard_list.append(current_shard)
            current_size = 0
            current_shard = []

        current_size += tensor_size
        current_shard.append(name)

    # Flush on pending names rather than pending bytes, so a trailing shard made only of
    # zero-element tensors still gets an entry in the weight map.
    if current_shard:
        total_size += current_size
        shard_list.append(current_shard)

    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    num_shards = len(shard_list)
    weight_map = OrderedDict()
    is_sharded = num_shards > 1
    if is_sharded:
        prefix, extension = weights_name.rsplit(".", maxsplit=1)
        for shard_idx, shard in enumerate(shard_list):
            file_name = f"{prefix}-{shard_idx + 1:05d}-of-{num_shards:05d}.{extension}"
            for name in shard:
                weight_map[name] = file_name
    else:
        # Zero or one shard: everything (possibly nothing) goes into the single file name.
        for shard in shard_list:
            for name in shard:
                weight_map[name] = weights_name

    return is_sharded, total_size, weight_map


def _save_state_dict(
    state_dict: Dict[str, "torch.Tensor"],
    path_to_save: "os.PathLike",
    safe_serialization: bool,
) -> None:
    """
    Save function.

    Removes any existing file at `path_to_save`, then writes `state_dict` with
    `safetensors.torch.save_file` (metadata `{"format": "pt"}`) or `torch.save`.
    """
    if os.path.exists(path_to_save):
        os.remove(path_to_save)

    if safe_serialization:
        save_file(state_dict, path_to_save, metadata={"format": "pt"})
    else:
        torch.save(state_dict, path_to_save)


@torch.no_grad()
def save_model_weights(
    output_dir: Union[str, "os.PathLike"],
    state_dict: Dict[str, "torch.Tensor"],
    global_rank: Optional[int] = None,
    save_dtype: Optional[Union[str, "torch.dtype"]] = "bfloat16",
    shard_size: int = 5_000_000_000,
    safe_serialization: bool = True,
    model_assets: Optional[Sequence["ModelAssets"]] = None,
) -> None:
    """
    Saves full model weights. The model parameters should be either tensor or dtensor.

    If global_rank is given, it will assume it is executed on all ranks.

    Args:
        output_dir: Directory to write into; created if missing.
        state_dict: Mapping from name to tensor or dtensor. Dtensors are gathered with
            `full_tensor()` on every rank that runs this function.
        global_rank: Rank of the caller. Files are written only when it is None or 0; when
            it is not None and the process group is initialized, ranks are synchronized
            with a barrier at every shard boundary.
        save_dtype: Dtype to cast tensors to before saving (string name or `torch.dtype`);
            None keeps the original dtypes.
        shard_size: Maximum shard size in bytes.
        safe_serialization: Write safetensors files instead of `torch.save` files.
        model_assets: Optional objects saved next to the weights via `save_model_assets`.
    """
    os.makedirs(output_dir, exist_ok=True)
    is_sharded, total_size, weight_map = _get_shard_info(state_dict, save_dtype, shard_size, safe_serialization)
    target_dtype = _resolve_save_dtype(save_dtype)
    is_writer = global_rank is None or global_rank == 0

    full_state_dict = OrderedDict()
    prev_file_name = None
    for name, tensor in state_dict.items():
        if hasattr(tensor.data, "full_tensor"):  # dtensor
            tensor = tensor.data.full_tensor()
        else:
            tensor = tensor.data

        if target_dtype is not None:
            tensor = tensor.to(dtype=target_dtype)

        # Crossing into the next shard: flush what was accumulated for the previous file.
        if prev_file_name is not None and weight_map[name] != prev_file_name:
            if is_writer:
                _save_state_dict(full_state_dict, os.path.join(output_dir, prev_file_name), safe_serialization)
                full_state_dict = OrderedDict()

            empty_cache()
            if global_rank is not None and dist.is_initialized():  # avoid process hanging
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                dist.barrier()

        if is_writer:
            full_state_dict[name] = tensor.detach().cpu()

        prev_file_name = weight_map[name]
        del tensor

    if is_writer:
        if len(full_state_dict):
            _save_state_dict(full_state_dict, os.path.join(output_dir, prev_file_name), safe_serialization)

        if is_sharded:
            index = {
                "metadata": {"total_size": total_size},
                "weight_map": weight_map,
            }
            index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
            with open(os.path.join(output_dir, index_file), "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)

            logger.info(f"Model weight splits saved in {output_dir}.")
        elif prev_file_name is not None:
            logger.info(f"Model weights saved at {os.path.join(output_dir, prev_file_name)}.")
        else:
            logger.warning(f"State dict is empty, no model weights saved in {output_dir}.")

        if model_assets is not None:
            save_model_assets(output_dir, model_assets)


def save_model_assets(output_dir: Union[str, "os.PathLike"], model_assets: Sequence["ModelAssets"]):
    """
    Saves every asset that implements `save_pretrained` into `output_dir`.

    Assets without a `save_pretrained` attribute are skipped with a warning.
    """
    for model_asset in model_assets:
        if hasattr(model_asset, "save_pretrained"):
            model_asset.save_pretrained(output_dir)
        else:
            logger.warning(f"Model asset {model_asset} should implement `save_pretrained`.")