# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch

from cosmos_predict1.utils import log

# Substrings to ignore when processing state dicts
substrings_to_ignore = [
    "_extra_state",  # Extra states (BytesIO type) added by TransformerEngine for FP8 handling
]
def identify_checkpoint_backend(state_dict: dict[str, torch.Tensor]) -> str:
    """
    Identify the backend of the checkpoint (PyTorch or TransformerEngine).

    Detection is done purely by key naming: TransformerEngine checkpoints use
    fused "layernorm_qkv" parameter names, PyTorch checkpoints use "wq"-style names.

    Args:
        state_dict (dict[str, torch.Tensor]): The state dict to check

    Returns:
        str: The backend of the checkpoint ("transformer_engine" or "pytorch")

    Raises:
        ValueError: If neither backend's marker key is present.
    """
    te_marker = "self_attention.layernorm_qkv.query_weight"
    pt_marker = "attention.wq.weight"
    for name in state_dict:
        if te_marker in name:
            return "transformer_engine"
        if pt_marker in name:
            return "pytorch"
    raise ValueError("Could not identify the backend of the checkpoint")
def get_partial_state_dict(
    state_dict: dict[str, torch.Tensor],
    prefix: str,
) -> dict[str, torch.Tensor]:
    """
    Return the subset of ``state_dict`` whose keys begin with ``prefix``.

    Args:
        state_dict (dict[str, torch.Tensor]): The full state dict.
        prefix (str): Key prefix to filter on (e.g. "model.").

    Returns:
        dict[str, torch.Tensor]: A new dict containing only the matching entries.
    """
    filtered: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        if name.startswith(prefix):
            filtered[name] = tensor
    return filtered
def process_state_dict(
    state_dict: dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    prefix_to_remove: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Clean up a state dict for loading or conversion.

    - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
    - Move tensors to specified device and dtype if provided
    - Optionally strip a prefix from the keys

    Args:
        state_dict (dict[str, torch.Tensor]): The state dict to process
        device (str, optional): The device to move tensors to. Defaults to None (keep current device).
        dtype (torch.dtype, optional): The dtype to cast tensors to. Defaults to None (keep current dtype).
        prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.

    Returns:
        dict[str, torch.Tensor]: The processed state dict
    """
    tensor_kwargs: dict[str, Any] = {}
    if device is not None:
        tensor_kwargs["device"] = device
    if dtype is not None:
        tensor_kwargs["dtype"] = dtype

    new_state_dict: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        # Skip keys matching any ignored substring (e.g. TE's FP8 "_extra_state" blobs).
        if any(substr in key for substr in substrings_to_ignore):
            continue
        if tensor_kwargs:
            value = value.to(**tensor_kwargs)
        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
            key = key[len(prefix_to_remove) :]
        new_state_dict[key] = value
    return new_state_dict
def obtain_tensor_parallel_state_dict(
    whole_model_state_dict: dict[str, torch.Tensor],
    tensor_parallel_size: int,
    tensor_parallel_rank: int,
    model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Obtain the tensor parallel state dict shard for the current rank.

    Sharding scheme: q/k/v projections, feed-forward w1/w3, medusa heads, token
    embeddings and the output head are chunked along dim 0 (column-parallel);
    output projections (wo) and feed-forward w2 are chunked along dim 1
    (row-parallel); everything else (e.g. norm weights) is replicated unchanged.

    Args:
        whole_model_state_dict (dict[str, torch.Tensor]): The complete model state dict.
        tensor_parallel_size (int): The number of tensor parallel devices.
        tensor_parallel_rank (int): The rank of the current tensor parallel device.
        model_config: The model configuration (dict-like; "n_heads", "n_kv_heads",
            "dim", "context_dim" and "n_layers" are read here).
        target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used.

    Returns:
        dict[str, torch.Tensor]: The updated state dict shard for the current tensor parallel rank.
    """
    new_state_dict_shard = {}
    # Drop TE "_extra_state" entries before identifying/sharding the checkpoint.
    whole_model_state_dict = process_state_dict(whole_model_state_dict)
    source_backend = identify_checkpoint_backend(whole_model_state_dict)
    if source_backend != "pytorch":
        # Convert the checkpoint to PyTorch backend for checkpoint sharding
        whole_model_state_dict = maybe_convert_checkpoint_to_backend(
            whole_model_state_dict, target_backend="pytorch", model_config=model_config, source_backend=source_backend
        )
    n_heads = model_config["n_heads"]
    n_kv_heads = model_config["n_kv_heads"]
    dim = model_config["dim"]
    context_dim = model_config["context_dim"]
    for key, value in whole_model_state_dict.items():
        prefix = "model." if key.startswith("model.") else ""  # LLM's model prefix
        prefix = "transformer." if key.startswith("transformer.") else prefix  # VIT's model prefix
        key = key.replace(prefix, "")
        if key.startswith("layers."):
            layer_index = int(key.split("layers.")[1].split(".")[0])
            # Checkpoints may contain more layers than the config; extras are skipped.
            if layer_index >= model_config["n_layers"]:
                log.warning(
                    f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer."
                )
                continue
        if ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
            # Query projection: reshape to (heads, head_dim, dim) and take this rank's head slice.
            value = torch.chunk(value.view(n_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
            value = value.reshape(-1, dim)
        elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
            # Self-attention key/value projections: sharded over KV heads.
            value = torch.chunk(value.view(n_kv_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
            value = value.reshape(-1, dim)
        elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
            # Cross-attention k/v project from the context features, so the input dim differs.
            assert context_dim is not None
            value = torch.chunk(value.view(n_kv_heads, -1, context_dim), tensor_parallel_size, dim=0)[
                tensor_parallel_rank
            ]
            value = value.reshape(-1, context_dim)
        elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
            # Column-parallel linears: split output rows.
            value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
        elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
            # Row-parallel linears: split input columns.
            value = torch.chunk(value, tensor_parallel_size, dim=1)[tensor_parallel_rank]
        else:
            # Handle non-layer weights
            if key == "tok_embeddings.weight" or key == "output.weight" or "medusa_head" in key:
                value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
        new_state_dict_shard[prefix + key] = value
    if target_backend is None:
        target_backend = source_backend
    # Convert the shard back to the requested backend (no-op when it already matches).
    new_state_dict_shard = maybe_convert_checkpoint_to_backend(
        new_state_dict_shard,
        target_backend=target_backend,
        model_config=model_config,
        is_tensor_parallel_shard=True,
        tensor_parallel_size=tensor_parallel_size,
    )
    return new_state_dict_shard
def merge_tensor_parallel_state_dicts(
    state_dict_shards: list[dict[str, torch.Tensor]],
    model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge tensor parallel state dict shards into a whole model state dict.

    Inverse of `obtain_tensor_parallel_state_dict`: column-parallel weights are
    concatenated along dim 0, row-parallel weights along dim 1, and replicated
    tensors (expected to be norm layers) are averaged across shards.

    Args:
        state_dict_shards (List[Dict[str, torch.Tensor]]): The list of state dict shards to merge.
        model_config: The model configuration.
        target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used.

    Returns:
        Dict[str, torch.Tensor]: The merged state dict.
    """
    # Move every shard to CPU and drop TE "_extra_state" entries.
    state_dict_shards = [process_state_dict(shard, device="cpu") for shard in state_dict_shards]
    tensor_parallel_size = len(state_dict_shards)
    source_backend = identify_checkpoint_backend(state_dict_shards[0])
    if source_backend != "pytorch":
        log.critical(f"Converting from {source_backend} to PyTorch backend for tensor parallel checkpoint merging.")
        state_dict_shards = [
            maybe_convert_checkpoint_to_backend(
                shard,
                target_backend="pytorch",
                model_config=model_config,
                source_backend=source_backend,
                is_tensor_parallel_shard=True,
                tensor_parallel_size=tensor_parallel_size,
            )
            for shard in state_dict_shards
        ]
    n_heads = model_config["n_heads"]
    n_kv_heads = model_config["n_kv_heads"]
    # Per-shard head counts.
    n_local_heads = n_heads // tensor_parallel_size
    n_local_kv_heads = n_kv_heads // tensor_parallel_size
    dim = model_config["dim"]
    context_dim = model_config["context_dim"]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        # Fall back to the conventional head size when the config does not pin it.
        head_dim = model_config["dim"] // model_config["n_heads"]
    query_dim = head_dim * n_heads
    key_value_dim = head_dim * n_kv_heads
    merged_state_dict = {}
    for key in state_dict_shards[0].keys():
        prefix = "model." if key.startswith("model.") else ""
        key_without_prefix = key[len(prefix) :]
        if key_without_prefix.startswith("layers."):
            layer_index = int(key_without_prefix.split("layers.")[1].split(".")[0])
            # Shards may contain more layers than the config; extras are skipped.
            if layer_index >= model_config["n_layers"]:
                log.warning(
                    f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer."
                )
                continue
        if key_without_prefix == "tok_embeddings.weight" or key_without_prefix == "output.weight":
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
        elif ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
            # Re-assemble the head-sharded query projection.
            chunks = [shard[key].view(n_local_heads, head_dim, dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(query_dim, dim)
        elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
            chunks = [shard[key].view(n_local_kv_heads, head_dim, dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, dim)
        elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
            # Cross-attention k/v project from the context features (different input dim).
            chunks = [shard[key].view(n_local_kv_heads, head_dim, context_dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, context_dim)
        elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
            # Column-parallel linears were split along dim 0.
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
        elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
            # Row-parallel linears were split along dim 1.
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=1)
        else:
            # Replicated tensors: average across shards, and sanity-check they agree.
            avg_tensor = torch.stack([shard[key] for shard in state_dict_shards]).mean(dim=0)
            # make sure shard-0 is close to the average tensor
            assert torch.allclose(state_dict_shards[0][key], avg_tensor, atol=5e-2, rtol=0.1), (
                f"Shard-0 tensor {key} is not close to the average tensor. "
                f"Max diff: {torch.max(torch.abs(state_dict_shards[0][key] - avg_tensor))}, "
            )
            merged_state_dict[key] = avg_tensor
            assert "norm" in key, f"Assumed the key {key} is a norm layer, which should be the same across shards."
    if target_backend is None:
        target_backend = source_backend
    return maybe_convert_checkpoint_to_backend(
        merged_state_dict, target_backend=target_backend, model_config=model_config
    )
def te_to_pytorch_state_dict(
    te_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a TransformerEngine state dict to PyTorch state dict.

    Key names are rewritten by substring replacement (first matching rule wins);
    the fused `layernorm_mlp.fc1_weight` tensor is split back into
    `feed_forward.w1` / `feed_forward.w3`; and for rope-v1 models the query/key
    weights are un-permuted back to the PyTorch layout.

    Args:
        te_state_dict (Mapping[str, torch.Tensor]): The TransformerEngine state dict
        model_config: The model configuration (dict, or an object exposing `asdict()`)
        tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).

    Returns:
        Mapping[str, torch.Tensor]: The PyTorch state dict
    """
    if hasattr(model_config, "asdict"):
        model_config = model_config.asdict()
    pytorch_state_dict = {}
    # (TE substring, PyTorch substring) pairs. Order matters: the specific rules
    # must precede the generic ".self_attention."/".inter_attention." fallbacks.
    replacement_rules = [
        # Self-attention modules
        (".self_attention.layernorm_qkv.layer_norm_weight", ".attention_norm.weight"),
        (".self_attention.layernorm_qkv.query_weight", ".attention.wq.weight"),
        (".self_attention.layernorm_qkv.key_weight", ".attention.wk.weight"),
        (".self_attention.layernorm_qkv.value_weight", ".attention.wv.weight"),
        (".self_attention.proj.weight", ".attention.wo.weight"),
        (".self_attention.", ".attention."),  # Handle the rest modules such as q_norm and k_norm
        # MLP modules
        (".layernorm_mlp.layer_norm_weight", ".ffn_norm.weight"),
        (".layernorm_mlp.fc2_weight", ".feed_forward.w2.weight"),
        # Cross-attention modules
        (".inter_attention.layernorm_query.query_weight", ".cross_attention.wq.weight"),
        (".inter_attention.key_value.key_weight", ".cross_attention.wk.weight"),
        (".inter_attention.key_value.value_weight", ".cross_attention.wv.weight"),
        (".inter_attention.proj.weight", ".cross_attention.wo.weight"),
        (".inter_attention.layernorm_query.layer_norm_weight", ".cross_attention_norm.weight"),
        (".inter_attention.", ".cross_attention."),  # Handle the rest modules such as q_norm and k_norm
    ]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    for old_key, value in te_state_dict.items():
        new_key = old_key
        for old_substr, new_substr in replacement_rules:
            if old_substr in new_key:
                new_key = new_key.replace(old_substr, new_substr)
                break  # Apply only the first matching rule per key.
        # Handle the fused w1 and w3 case
        if "layernorm_mlp.fc1_weight" in old_key:
            # TE fuses w1 and w3 into fc1, stacked along dim 0; split them evenly.
            fused_weight = value
            split_point = fused_weight.shape[0] // 2
            w1_weight = fused_weight[:split_point]
            w3_weight = fused_weight[split_point:]
            w1_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w1.weight")
            w3_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w3.weight")
            pytorch_state_dict[w1_key] = w1_weight
            pytorch_state_dict[w3_key] = w3_weight
        else:
            if model_config["pytorch_rope_version"] == "v1":
                # If the model uses qk normalization, we will use the same PyTorch RoPE operations as the TE version.
                # Thus, we do not need to permute the weights (this branch only runs for rope v1).
                if "query_weight" in old_key:
                    value = inverse_permute_weight(
                        value,
                        n_heads=model_config["n_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
                        dim2=model_config["dim"],
                    )
                elif "key_weight" in old_key:
                    value = inverse_permute_weight(
                        value,
                        n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
                        # Cross-attention keys project from the context features.
                        dim2=model_config["context_dim"] if "inter_attention" in old_key else model_config["dim"],
                    )
            pytorch_state_dict[new_key] = value
    return pytorch_state_dict
def pytorch_to_te_state_dict(
    pytorch_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a PyTorch state dict to TransformerEngine state dict.

    Inverse of `te_to_pytorch_state_dict`: key names are rewritten by substring
    replacement (first matching rule wins); `feed_forward.w1` / `feed_forward.w3`
    are fused into `layernorm_mlp.fc1_weight`; and for rope-v1 models the
    query/key weights are permuted into the TE layout.

    Args:
        pytorch_state_dict (Mapping[str, torch.Tensor]): The PyTorch state dict
        model_config: The model configuration (dict, or an object exposing `asdict()`)
        tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).

    Returns:
        Mapping[str, torch.Tensor]: The TransformerEngine state dict
    """
    if hasattr(model_config, "asdict"):
        model_config = model_config.asdict()
    te_state_dict = {}
    # (PyTorch substring, TE substring) pairs. Order matters: specific rules must
    # precede the generic ".attention."/".cross_attention." fallbacks.
    replacement_rules = [
        # Self-attention modules
        (".attention_norm.weight", ".self_attention.layernorm_qkv.layer_norm_weight"),
        (".attention.wq.weight", ".self_attention.layernorm_qkv.query_weight"),
        (".attention.wk.weight", ".self_attention.layernorm_qkv.key_weight"),
        (".attention.wv.weight", ".self_attention.layernorm_qkv.value_weight"),
        (".attention.wo.weight", ".self_attention.proj.weight"),
        (".attention.", ".self_attention."),
        # MLP modules
        (".ffn_norm.weight", ".layernorm_mlp.layer_norm_weight"),
        (".feed_forward.w2.weight", ".layernorm_mlp.fc2_weight"),
        # Cross-attention modules
        (".cross_attention_norm.weight", ".inter_attention.layernorm_query.layer_norm_weight"),
        (".cross_attention.wq.weight", ".inter_attention.layernorm_query.query_weight"),
        (".cross_attention.wk.weight", ".inter_attention.key_value.key_weight"),
        (".cross_attention.wv.weight", ".inter_attention.key_value.value_weight"),
        (".cross_attention.wo.weight", ".inter_attention.proj.weight"),
        (".cross_attention.", ".inter_attention."),
    ]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    for old_key, value in pytorch_state_dict.items():
        new_key = old_key
        for new_substr, old_substr in replacement_rules:
            if new_substr in new_key:
                new_key = new_key.replace(new_substr, old_substr)
                break  # Apply only the first matching rule per key.
        # Handle the split w1 and w3 case
        if "feed_forward.w1.weight" in old_key:
            # Fuse w1 and w3 (stacked along dim 0) into TE's fc1 weight.
            w1_weight = value
            w3_key = old_key.replace("feed_forward.w1.weight", "feed_forward.w3.weight")
            if w3_key in pytorch_state_dict:
                w3_weight = pytorch_state_dict[w3_key]
                fused_weight = torch.cat([w1_weight, w3_weight], dim=0)
                new_key = new_key.replace("feed_forward.w1.weight", "layernorm_mlp.fc1_weight")
                te_state_dict[new_key] = fused_weight
            else:
                te_state_dict[new_key] = value
        elif "feed_forward.w3.weight" in old_key:
            # Skip w3 weights as they're handled with w1
            continue
        else:
            if model_config["pytorch_rope_version"] == "v1":
                # If the model uses qk normalization, we will use the same PyTorch RoPE operations as the TE version.
                # Thus, we do not need to permute the weights (this branch only runs for rope v1).
                if "attention.wq" in old_key:
                    value = permute_weight(
                        value,
                        n_heads=model_config["n_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
                        dim2=model_config["dim"],
                    )
                elif "attention.wk" in old_key:
                    value = permute_weight(
                        value,
                        n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
                        # Cross-attention keys project from the context features.
                        dim2=model_config["context_dim"] if "cross_attention" in old_key else model_config["dim"],
                    )
            te_state_dict[new_key] = value
    return te_state_dict
def permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """
    Helper function for converting checkpoints from PyTorch to TransformerEngine.

    Permute the query weight or key weight of each attention layer by swapping
    the two rotary sub-blocks of every head.
    Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py

    Args:
        w (torch.Tensor): The weight tensor to permute
        n_heads (int): The number of attention heads
        dim1 (int): The first dimension of the weight tensor
        dim2 (int): The second dimension of the weight tensor

    Returns:
        torch.Tensor: The permuted weight tensor
    """
    half_rows = dim1 // n_heads // 2
    grouped = w.view(n_heads, half_rows, 2, dim2)
    return grouped.transpose(1, 2).reshape(dim1, dim2)
def inverse_permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """
    Helper function for converting checkpoints from TransformerEngine to PyTorch.

    Permute the query weight or key weight of each attention layer; this is the
    exact inverse of `permute_weight`.

    Args:
        w (torch.Tensor): The weight tensor to permute
        n_heads (int): The number of attention heads
        dim1 (int): The first dimension of the weight tensor
        dim2 (int): The second dimension of the weight tensor

    Returns:
        torch.Tensor: The permuted weight tensor
    """
    pair_rows = dim1 // n_heads // 2
    grouped = w.view(n_heads, 2, pair_rows, dim2)
    return grouped.transpose(1, 2).reshape(dim1, dim2)
def pytorch_to_hf_state_dict(
    state_dict: Dict[str, torch.Tensor], model_config: Dict[str, Any], tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a PyTorch state dict to HuggingFace format for LLM models.

    Args:
        state_dict (Mapping[str, torch.Tensor]):
            The original PyTorch model's state dictionary, mapping layer names to
            the corresponding weight tensors.
        model_config (Mapping[str, Any]):
            The configuration of the model, providing parameters such as:
                - n_layers: (int) The number of transformer layers.
                - n_heads: (int) The number of attention heads.
                - dim: (int) The hidden size of the model.
                - n_kv_heads: (int, optional) The number of key-value heads for multi-query attention.

    Returns:
        Mapping[str, torch.Tensor]:
            The converted HuggingFace state dictionary, mapping HuggingFace
            transformer-compatible layer names to the corresponding weights.

    Raises:
        ValueError: If the state dict contains modules with no HuggingFace equivalent
            (cross-attention, q_norm, k_norm).
    """
    # HF export has no equivalent for these modules; reject checkpoints containing them.
    unsupported_markers = ["cross_attention", "q_norm", "k_norm"]
    for key in state_dict.keys():
        for marker in unsupported_markers:
            if marker in key:
                raise ValueError(f"Key {key} is not supported in HuggingFace format.")
    assert tensor_parallel_size == 1, "Tensor parallel size > 1 is not supported for HuggingFace model export."

    num_layers = model_config["n_layers"]
    num_heads = model_config["n_heads"]
    hidden_size = model_config["dim"]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    num_key_value_heads = model_config.get("n_kv_heads", num_heads)
    key_value_dim = head_dim * num_key_value_heads

    hf_state_dict: Dict[str, torch.Tensor] = {}
    for layer_idx in range(num_layers):
        src = f"layers.{layer_idx}."
        dst = f"model.layers.{layer_idx}."
        wq = state_dict[f"{src}attention.wq.weight"]
        wk = state_dict[f"{src}attention.wk.weight"]
        if model_config["pytorch_rope_version"] == "v1":
            # rope-v1 checkpoints need their q/k rows re-interleaved for the HF layout.
            wq = permute_weight(wq, n_heads=num_heads, dim1=hidden_size, dim2=hidden_size)
            wk = permute_weight(wk, n_heads=num_key_value_heads, dim1=key_value_dim, dim2=hidden_size)
        hf_state_dict[f"{dst}self_attn.q_proj.weight"] = wq
        hf_state_dict[f"{dst}self_attn.k_proj.weight"] = wk
        hf_state_dict[f"{dst}self_attn.v_proj.weight"] = state_dict[f"{src}attention.wv.weight"]
        hf_state_dict[f"{dst}self_attn.o_proj.weight"] = state_dict[f"{src}attention.wo.weight"]
        hf_state_dict[f"{dst}mlp.gate_proj.weight"] = state_dict[f"{src}feed_forward.w1.weight"]
        hf_state_dict[f"{dst}mlp.down_proj.weight"] = state_dict[f"{src}feed_forward.w2.weight"]
        hf_state_dict[f"{dst}mlp.up_proj.weight"] = state_dict[f"{src}feed_forward.w3.weight"]
        hf_state_dict[f"{dst}input_layernorm.weight"] = state_dict[f"{src}attention_norm.weight"]
        hf_state_dict[f"{dst}post_attention_layernorm.weight"] = state_dict[f"{src}ffn_norm.weight"]

    # Add non-layer weights
    hf_state_dict["model.embed_tokens.weight"] = state_dict["tok_embeddings.weight"]
    hf_state_dict["model.norm.weight"] = state_dict["norm.weight"]
    hf_state_dict["lm_head.weight"] = state_dict["output.weight"]
    return hf_state_dict
def maybe_convert_checkpoint_to_backend(
    state_dict: Dict[str, torch.Tensor],
    target_backend: str,
    model_config,
    source_backend: Optional[str] = None,
    is_tensor_parallel_shard: bool = False,
    tensor_parallel_size: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """
    Identify the backend of the checkpoint and convert to the target backend if necessary.

    This function checks the current backend of the state_dict and converts it to the target backend
    if they don't match. It supports conversions between PyTorch, TransformerEngine, and HuggingFace backends.

    Args:
        state_dict (Dict[str, torch.Tensor]): The model state dictionary to convert.
        target_backend (str): The desired backend format ('pytorch', 'transformer_engine', or 'huggingface').
        model_config: Configuration of the model, used in conversion process.
        source_backend (str, optional): The current backend of the state_dict. If not specified, the function will identify the backend.
        is_tensor_parallel_shard (bool, optional): Whether the state_dict is a tensor parallel shard. Defaults to False.
        tensor_parallel_size (int, optional): The tensor parallel size. If not specified, it is taken from model_config for shards, else 1.

    Returns:
        Dict[str, torch.Tensor]: The converted state dictionary in the target backend format.

    Raises:
        ValueError: If the conversion between the identified backend and target backend is not supported
            (only pytorch->te, te->pytorch and pytorch->huggingface are implemented).
    """
    # Identify the current backend of the checkpoint
    state_dict = process_state_dict(state_dict)  # Remove unnecessary keys
    if source_backend is None:
        source_backend = identify_checkpoint_backend(state_dict)
    if source_backend == target_backend:
        # Already in the requested format; return the cleaned state dict as-is.
        return state_dict
    else:
        if tensor_parallel_size is None:
            tensor_parallel_size = model_config["tensor_parallel_size"] if is_tensor_parallel_shard else 1
        # Convert to target backend
        if source_backend == "pytorch" and target_backend == "transformer_engine":
            return pytorch_to_te_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        elif source_backend == "transformer_engine" and target_backend == "pytorch":
            return te_to_pytorch_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        elif source_backend == "pytorch" and target_backend == "huggingface":
            return pytorch_to_hf_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        else:
            raise ValueError(f"Conversion from {source_backend} to {target_backend} is not supported.")