from typing import Any, Dict, Optional

import torch

from cosmos_predict1.utils import log

# Keys containing any of these substrings are dropped by process_state_dict.
# TransformerEngine adds "_extra_state" entries to store FP8 metadata.
substrings_to_ignore = [
    "_extra_state",
]

|
def identify_checkpoint_backend(state_dict: dict[str, torch.Tensor]) -> str:
    """
    Identify the backend of the checkpoint (PyTorch or TransformerEngine).

    Args:
        state_dict (dict[str, torch.Tensor]): The state dict to check

    Returns:
        str: The backend of the checkpoint, either "pytorch" or "transformer_engine"
    """
    for key in state_dict.keys():
        if "self_attention.layernorm_qkv.query_weight" in key:
            return "transformer_engine"
        elif "attention.wq.weight" in key:
            return "pytorch"
    raise ValueError("Could not identify the backend of the checkpoint")

|
def get_partial_state_dict(
    state_dict: dict[str, torch.Tensor],
    prefix: str,
) -> dict[str, torch.Tensor]:
    """
    Get a partial state dict with keys starting with the given prefix
    """
    return {k: v for k, v in state_dict.items() if k.startswith(prefix)}

|
def process_state_dict(
    state_dict: dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    prefix_to_remove: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
    - Move tensors to the specified device and dtype if provided

    Args:
        state_dict (dict[str, torch.Tensor]): The state dict to process
        device (str, optional): The device to move tensors to. Defaults to None.
        dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
        prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.

    Returns:
        dict[str, torch.Tensor]: The processed state dict
    """
    new_state_dict = {}
    tensor_kwargs = {}
    if device is not None:
        tensor_kwargs["device"] = device
    if dtype is not None:
        tensor_kwargs["dtype"] = dtype

    for key, value in state_dict.items():
        # Skip keys that only hold backend-specific metadata (e.g. TransformerEngine "_extra_state").
        skip = False
        for substr in substrings_to_ignore:
            if substr in key:
                skip = True
                break
        if skip:
            continue
        if len(tensor_kwargs) > 0:
            value = value.to(**tensor_kwargs)
        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
            key = key[len(prefix_to_remove) :]
        new_state_dict[key] = value
    return new_state_dict

|
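# Example usage (hypothetical checkpoint path and prefix): strip TransformerEngine
# "_extra_state" entries, keep tensors on CPU, and drop a leading "model." prefix
# before loading into a module.
#
#   raw_state_dict = torch.load("checkpoint.pt", map_location="cpu")
#   clean_state_dict = process_state_dict(raw_state_dict, device="cpu", prefix_to_remove="model.")
#   model.load_state_dict(clean_state_dict, strict=False)
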
def obtain_tensor_parallel_state_dict(
    whole_model_state_dict: dict[str, torch.Tensor],
    tensor_parallel_size: int,
    tensor_parallel_rank: int,
    model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Obtain the tensor parallel state dict shard for the current rank.

    Args:
        whole_model_state_dict (dict[str, torch.Tensor]): The complete model state dict.
        tensor_parallel_size (int): The number of tensor parallel devices.
        tensor_parallel_rank (int): The rank of the current tensor parallel device.
        model_config: The model configuration.
        target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface').
            If not specified, the source backend will be used.

    Returns:
        dict[str, torch.Tensor]: The state dict shard for the current tensor parallel rank.
    """
    new_state_dict_shard = {}
    whole_model_state_dict = process_state_dict(whole_model_state_dict)
    source_backend = identify_checkpoint_backend(whole_model_state_dict)
    if source_backend != "pytorch":
        # Sharding is performed on PyTorch-style keys, so convert first and convert back at the end.
        whole_model_state_dict = maybe_convert_checkpoint_to_backend(
            whole_model_state_dict, target_backend="pytorch", model_config=model_config, source_backend=source_backend
        )

    n_heads = model_config["n_heads"]
    n_kv_heads = model_config["n_kv_heads"]
    dim = model_config["dim"]
    context_dim = model_config["context_dim"]
    for key, value in whole_model_state_dict.items():
        prefix = "model." if key.startswith("model.") else ""
        prefix = "transformer." if key.startswith("transformer.") else prefix
        key = key.replace(prefix, "")
        if key.startswith("layers."):
            layer_index = int(key.split("layers.")[1].split(".")[0])
            if layer_index >= model_config["n_layers"]:
                log.warning(
                    f"Layer index {layer_index} is out of range for n_layers={model_config['n_layers']}. Skipping this layer."
                )
                continue
        if ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
            # Column-parallel: split query projections along the head dimension.
            value = torch.chunk(value.view(n_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
            value = value.reshape(-1, dim)
        elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
            value = torch.chunk(value.view(n_kv_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
            value = value.reshape(-1, dim)
        elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
            assert context_dim is not None
            value = torch.chunk(value.view(n_kv_heads, -1, context_dim), tensor_parallel_size, dim=0)[
                tensor_parallel_rank
            ]
            value = value.reshape(-1, context_dim)
        elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
            # Column-parallel: split along the output dimension.
            value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
        elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
            # Row-parallel: split along the input dimension.
            value = torch.chunk(value, tensor_parallel_size, dim=1)[tensor_parallel_rank]
        else:
            # Embedding and output projection are split along the vocabulary dimension;
            # all remaining tensors (e.g. norm layers) are replicated on every rank.
            if key == "tok_embeddings.weight" or key == "output.weight" or "medusa_head" in key:
                value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
        new_state_dict_shard[prefix + key] = value

    if target_backend is None:
        target_backend = source_backend

    new_state_dict_shard = maybe_convert_checkpoint_to_backend(
        new_state_dict_shard,
        target_backend=target_backend,
        model_config=model_config,
        is_tensor_parallel_shard=True,
        tensor_parallel_size=tensor_parallel_size,
    )

    return new_state_dict_shard

|
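# Example usage (hypothetical config and paths): split a full checkpoint into one shard
# per tensor parallel rank and save each shard separately.
#
#   full_state_dict = torch.load("model.pt", map_location="cpu")
#   for rank in range(tensor_parallel_size):
#       shard = obtain_tensor_parallel_state_dict(
#           full_state_dict, tensor_parallel_size, rank, model_config, target_backend="pytorch"
#       )
#       torch.save(shard, f"model_tp_rank{rank}.pt")
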
def merge_tensor_parallel_state_dicts(
    state_dict_shards: list[dict[str, torch.Tensor]],
    model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge tensor parallel state dict shards into a whole model state dict.

    Args:
        state_dict_shards (list[dict[str, torch.Tensor]]): The list of state dict shards to merge.
        model_config: The model configuration.
        target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface').
            If not specified, the source backend will be used.

    Returns:
        dict[str, torch.Tensor]: The merged state dict.
    """
    state_dict_shards = [process_state_dict(shard, device="cpu") for shard in state_dict_shards]
    tensor_parallel_size = len(state_dict_shards)
    source_backend = identify_checkpoint_backend(state_dict_shards[0])
    if source_backend != "pytorch":
        log.critical(f"Converting from {source_backend} to PyTorch backend for tensor parallel checkpoint merging.")
        state_dict_shards = [
            maybe_convert_checkpoint_to_backend(
                shard,
                target_backend="pytorch",
                model_config=model_config,
                source_backend=source_backend,
                is_tensor_parallel_shard=True,
                tensor_parallel_size=tensor_parallel_size,
            )
            for shard in state_dict_shards
        ]

    n_heads = model_config["n_heads"]
    n_kv_heads = model_config["n_kv_heads"]
    n_local_heads = n_heads // tensor_parallel_size
    n_local_kv_heads = n_kv_heads // tensor_parallel_size
    dim = model_config["dim"]
    context_dim = model_config["context_dim"]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    query_dim = head_dim * n_heads
    key_value_dim = head_dim * n_kv_heads
    merged_state_dict = {}

    for key in state_dict_shards[0].keys():
        prefix = "model." if key.startswith("model.") else ""
        key_without_prefix = key[len(prefix) :]
        if key_without_prefix.startswith("layers."):
            layer_index = int(key_without_prefix.split("layers.")[1].split(".")[0])
            if layer_index >= model_config["n_layers"]:
                log.warning(
                    f"Layer index {layer_index} is out of range for n_layers={model_config['n_layers']}. Skipping this layer."
                )
                continue
        if key_without_prefix == "tok_embeddings.weight" or key_without_prefix == "output.weight":
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
        elif ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
            chunks = [shard[key].view(n_local_heads, head_dim, dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(query_dim, dim)
        elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
            chunks = [shard[key].view(n_local_kv_heads, head_dim, dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, dim)
        elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
            chunks = [shard[key].view(n_local_kv_heads, head_dim, context_dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, context_dim)
        elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
        elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=1)
        else:
            # Remaining tensors (norm layers) are replicated across shards; average them and
            # verify that shard 0 does not deviate meaningfully from the average.
            avg_tensor = torch.stack([shard[key] for shard in state_dict_shards]).mean(dim=0)
            assert torch.allclose(state_dict_shards[0][key], avg_tensor, atol=5e-2, rtol=0.1), (
                f"Shard-0 tensor {key} is not close to the average tensor. "
                f"Max diff: {torch.max(torch.abs(state_dict_shards[0][key] - avg_tensor))}"
            )
            merged_state_dict[key] = avg_tensor
            assert "norm" in key, f"Assumed the key {key} is a norm layer, which should be the same across shards."

    if target_backend is None:
        target_backend = source_backend
    return maybe_convert_checkpoint_to_backend(
        merged_state_dict, target_backend=target_backend, model_config=model_config
    )

|
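# Example usage (hypothetical paths): load the per-rank shards produced above and merge
# them back into a single full checkpoint.
#
#   shards = [torch.load(f"model_tp_rank{rank}.pt", map_location="cpu") for rank in range(tensor_parallel_size)]
#   merged = merge_tensor_parallel_state_dicts(shards, model_config, target_backend="pytorch")
#   torch.save(merged, "model_merged.pt")
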
def te_to_pytorch_state_dict(
    te_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a TransformerEngine state dict to a PyTorch state dict.

    Args:
        te_state_dict (Dict[str, torch.Tensor]): The TransformerEngine state dict
        model_config: The model configuration
        tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).

    Returns:
        Dict[str, torch.Tensor]: The PyTorch state dict
    """
    if hasattr(model_config, "asdict"):
        model_config = model_config.asdict()

    pytorch_state_dict = {}
    replacement_rules = [
        # Self-attention
        (".self_attention.layernorm_qkv.layer_norm_weight", ".attention_norm.weight"),
        (".self_attention.layernorm_qkv.query_weight", ".attention.wq.weight"),
        (".self_attention.layernorm_qkv.key_weight", ".attention.wk.weight"),
        (".self_attention.layernorm_qkv.value_weight", ".attention.wv.weight"),
        (".self_attention.proj.weight", ".attention.wo.weight"),
        (".self_attention.", ".attention."),
        # Feed-forward (fc1 is handled separately below since it fuses w1 and w3)
        (".layernorm_mlp.layer_norm_weight", ".ffn_norm.weight"),
        (".layernorm_mlp.fc2_weight", ".feed_forward.w2.weight"),
        # Cross-attention
        (".inter_attention.layernorm_query.query_weight", ".cross_attention.wq.weight"),
        (".inter_attention.key_value.key_weight", ".cross_attention.wk.weight"),
        (".inter_attention.key_value.value_weight", ".cross_attention.wv.weight"),
        (".inter_attention.proj.weight", ".cross_attention.wo.weight"),
        (".inter_attention.layernorm_query.layer_norm_weight", ".cross_attention_norm.weight"),
        (".inter_attention.", ".cross_attention."),
    ]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    for old_key, value in te_state_dict.items():
        new_key = old_key
        for old_substr, new_substr in replacement_rules:
            if old_substr in new_key:
                new_key = new_key.replace(old_substr, new_substr)
                break

        if "layernorm_mlp.fc1_weight" in old_key:
            # TransformerEngine fuses w1 (gate) and w3 (up) into a single fc1 weight; split it back.
            fused_weight = value
            split_point = fused_weight.shape[0] // 2
            w1_weight = fused_weight[:split_point]
            w3_weight = fused_weight[split_point:]

            w1_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w1.weight")
            w3_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w3.weight")

            pytorch_state_dict[w1_key] = w1_weight
            pytorch_state_dict[w3_key] = w3_weight
        else:
            if model_config["pytorch_rope_version"] == "v1":
                # The v1 RoPE implementation stores query/key weights in a different head-dim
                # ordering, so undo the permutation applied by pytorch_to_te_state_dict.
                if "query_weight" in old_key:
                    value = inverse_permute_weight(
                        value,
                        n_heads=model_config["n_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
                        dim2=model_config["dim"],
                    )
                elif "key_weight" in old_key:
                    value = inverse_permute_weight(
                        value,
                        n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
                        dim2=model_config["context_dim"] if "inter_attention" in old_key else model_config["dim"],
                    )
            pytorch_state_dict[new_key] = value

    return pytorch_state_dict

|
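# Summary of the key-name mapping used by te_to_pytorch_state_dict (above) and
# pytorch_to_te_state_dict (below), one layer shown; the fc1 weight on the
# TransformerEngine side is the concatenation of w1 and w3 along dim 0:
#   layers.0.attention_norm.weight      <-> layers.0.self_attention.layernorm_qkv.layer_norm_weight
#   layers.0.attention.wq.weight        <-> layers.0.self_attention.layernorm_qkv.query_weight
#   layers.0.attention.wo.weight        <-> layers.0.self_attention.proj.weight
#   layers.0.ffn_norm.weight            <-> layers.0.layernorm_mlp.layer_norm_weight
#   layers.0.feed_forward.w2.weight     <-> layers.0.layernorm_mlp.fc2_weight
#   layers.0.feed_forward.w{1,3}.weight <-> layers.0.layernorm_mlp.fc1_weight (fused)
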
def pytorch_to_te_state_dict(
    pytorch_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a PyTorch state dict to a TransformerEngine state dict.

    Args:
        pytorch_state_dict (Dict[str, torch.Tensor]): The PyTorch state dict
        model_config: The model configuration
        tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).

    Returns:
        Dict[str, torch.Tensor]: The TransformerEngine state dict
    """
    if hasattr(model_config, "asdict"):
        model_config = model_config.asdict()

    te_state_dict = {}

    # Each tuple maps a PyTorch key substring to its TransformerEngine counterpart.
    replacement_rules = [
        # Self-attention
        (".attention_norm.weight", ".self_attention.layernorm_qkv.layer_norm_weight"),
        (".attention.wq.weight", ".self_attention.layernorm_qkv.query_weight"),
        (".attention.wk.weight", ".self_attention.layernorm_qkv.key_weight"),
        (".attention.wv.weight", ".self_attention.layernorm_qkv.value_weight"),
        (".attention.wo.weight", ".self_attention.proj.weight"),
        (".attention.", ".self_attention."),
        # Feed-forward (w1 and w3 are fused into fc1 below)
        (".ffn_norm.weight", ".layernorm_mlp.layer_norm_weight"),
        (".feed_forward.w2.weight", ".layernorm_mlp.fc2_weight"),
        # Cross-attention
        (".cross_attention_norm.weight", ".inter_attention.layernorm_query.layer_norm_weight"),
        (".cross_attention.wq.weight", ".inter_attention.layernorm_query.query_weight"),
        (".cross_attention.wk.weight", ".inter_attention.key_value.key_weight"),
        (".cross_attention.wv.weight", ".inter_attention.key_value.value_weight"),
        (".cross_attention.wo.weight", ".inter_attention.proj.weight"),
        (".cross_attention.", ".inter_attention."),
    ]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    for old_key, value in pytorch_state_dict.items():
        new_key = old_key
        for new_substr, old_substr in replacement_rules:
            if new_substr in new_key:
                new_key = new_key.replace(new_substr, old_substr)
                break

        if "feed_forward.w1.weight" in old_key:
            # Fuse w1 (gate) and w3 (up) into TransformerEngine's single fc1 weight.
            w1_weight = value
            w3_key = old_key.replace("feed_forward.w1.weight", "feed_forward.w3.weight")
            if w3_key in pytorch_state_dict:
                w3_weight = pytorch_state_dict[w3_key]
                fused_weight = torch.cat([w1_weight, w3_weight], dim=0)
                new_key = new_key.replace("feed_forward.w1.weight", "layernorm_mlp.fc1_weight")
                te_state_dict[new_key] = fused_weight
            else:
                te_state_dict[new_key] = value
        elif "feed_forward.w3.weight" in old_key:
            # Already fused with the corresponding w1 weight above.
            continue
        else:
            if model_config["pytorch_rope_version"] == "v1":
                # The v1 RoPE implementation stores query/key weights in a different head-dim
                # ordering, so permute them into the TransformerEngine layout.
                if "attention.wq" in old_key:
                    value = permute_weight(
                        value,
                        n_heads=model_config["n_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
                        dim2=model_config["dim"],
                    )
                elif "attention.wk" in old_key:
                    value = permute_weight(
                        value,
                        n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
                        dim2=model_config["context_dim"] if "cross_attention" in old_key else model_config["dim"],
                    )
            te_state_dict[new_key] = value

    return te_state_dict

|
def permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """
    Helper function for converting checkpoints from PyTorch to TransformerEngine.
    Permute the query weight or key weight of each attention layer.
    Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py

    Args:
        w (torch.Tensor): The weight tensor to permute
        n_heads (int): The number of attention heads
        dim1 (int): The first dimension of the weight tensor
        dim2 (int): The second dimension of the weight tensor

    Returns:
        torch.Tensor: The permuted weight tensor
    """
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

|
def inverse_permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """
    Helper function for converting checkpoints from TransformerEngine to PyTorch.
    Inverse of permute_weight: restores the original ordering of the query or key weight.

    Args:
        w (torch.Tensor): The weight tensor to permute
        n_heads (int): The number of attention heads
        dim1 (int): The first dimension of the weight tensor
        dim2 (int): The second dimension of the weight tensor

    Returns:
        torch.Tensor: The permuted weight tensor
    """
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

|
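# Note: inverse_permute_weight undoes permute_weight. For a query weight of shape
# (n_heads * head_dim, dim) with an even head_dim, the following round trip is exact:
#
#   w = torch.randn(n_heads * head_dim, dim)
#   w_permuted = permute_weight(w, n_heads, n_heads * head_dim, dim)
#   assert torch.equal(inverse_permute_weight(w_permuted, n_heads, n_heads * head_dim, dim), w)
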
def pytorch_to_hf_state_dict(
    state_dict: Dict[str, torch.Tensor], model_config: Dict[str, Any], tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a PyTorch state dict to HuggingFace format for LLM models.

    Args:
        state_dict (Dict[str, torch.Tensor]):
            The original PyTorch model's state dictionary.
            This is a mapping where keys are layer names and values are the corresponding PyTorch tensors
            containing the model weights.

        model_config (Dict[str, Any]):
            The configuration of the model. This dictionary contains parameters such as:
            - n_layers: (int) The number of transformer layers.
            - n_heads: (int) The number of attention heads.
            - dim: (int) The hidden size of the model.
            - n_kv_heads: (int, optional) The number of key-value heads for multi-query attention.

    Returns:
        Dict[str, torch.Tensor]:
            The converted HuggingFace state dictionary. This dictionary maps HuggingFace transformer-compatible
            layer names to the corresponding model weights.
    """
    not_supported_key_substrings = ["cross_attention", "q_norm", "k_norm"]
    for key in state_dict.keys():
        if any(substr in key for substr in not_supported_key_substrings):
            raise ValueError(f"Key {key} is not supported in HuggingFace format.")
    assert tensor_parallel_size == 1, "Tensor parallel size > 1 is not supported for HuggingFace model export."

    hf_state_dict = {}

    n_layers = model_config["n_layers"]
    n_heads = model_config["n_heads"]
    dim = model_config["dim"]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]

    num_key_value_heads = model_config.get("n_kv_heads", n_heads)
    key_value_dim = head_dim * num_key_value_heads

    for layer_i in range(n_layers):
        pt_prefix = f"layers.{layer_i}."
        hf_prefix = f"model.layers.{layer_i}."

        wq = state_dict[f"{pt_prefix}attention.wq.weight"]
        wk = state_dict[f"{pt_prefix}attention.wk.weight"]
        if model_config["pytorch_rope_version"] == "v1":
            # HuggingFace expects the permuted query/key layout (see permute_weight above).
            wq = permute_weight(
                wq,
                n_heads=n_heads,
                dim1=dim,
                dim2=dim,
            )
            wk = permute_weight(
                wk,
                n_heads=num_key_value_heads,
                dim1=key_value_dim,
                dim2=dim,
            )
        hf_state_dict[f"{hf_prefix}self_attn.q_proj.weight"] = wq
        hf_state_dict[f"{hf_prefix}self_attn.k_proj.weight"] = wk
        hf_state_dict[f"{hf_prefix}self_attn.v_proj.weight"] = state_dict[f"{pt_prefix}attention.wv.weight"]
        hf_state_dict[f"{hf_prefix}self_attn.o_proj.weight"] = state_dict[f"{pt_prefix}attention.wo.weight"]
        hf_state_dict[f"{hf_prefix}mlp.gate_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w1.weight"]
        hf_state_dict[f"{hf_prefix}mlp.down_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w2.weight"]
        hf_state_dict[f"{hf_prefix}mlp.up_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w3.weight"]
        hf_state_dict[f"{hf_prefix}input_layernorm.weight"] = state_dict[f"{pt_prefix}attention_norm.weight"]
        hf_state_dict[f"{hf_prefix}post_attention_layernorm.weight"] = state_dict[f"{pt_prefix}ffn_norm.weight"]

    # Non-layer weights: embeddings, final norm, and the output projection.
    hf_state_dict["model.embed_tokens.weight"] = state_dict["tok_embeddings.weight"]
    hf_state_dict["model.norm.weight"] = state_dict["norm.weight"]
    hf_state_dict["lm_head.weight"] = state_dict["output.weight"]

    return hf_state_dict

|
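# The resulting keys follow the HuggingFace Llama-style naming, e.g.:
#   layers.0.attention.wq.weight     -> model.layers.0.self_attn.q_proj.weight
#   layers.0.feed_forward.w1.weight  -> model.layers.0.mlp.gate_proj.weight
#   layers.0.attention_norm.weight   -> model.layers.0.input_layernorm.weight
#   tok_embeddings.weight            -> model.embed_tokens.weight
#   output.weight                    -> lm_head.weight
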
def maybe_convert_checkpoint_to_backend(
    state_dict: Dict[str, torch.Tensor],
    target_backend: str,
    model_config,
    source_backend: Optional[str] = None,
    is_tensor_parallel_shard: bool = False,
    tensor_parallel_size: Optional[int] = None,
):
    """
    Identify the backend of the checkpoint and convert to the target backend if necessary.

    This function checks the current backend of the state_dict and converts it to the target backend
    if they don't match. It supports conversions between PyTorch, TransformerEngine, and HuggingFace backends.

    Args:
        state_dict (Dict[str, torch.Tensor]): The model state dictionary to convert.
        target_backend (str): The desired backend format ('pytorch', 'transformer_engine', or 'huggingface').
        model_config: Configuration of the model, used in the conversion process.
        source_backend (str, optional): The current backend of the state_dict. If not specified, the function will identify the backend.
        is_tensor_parallel_shard (bool, optional): Whether the state_dict is a tensor parallel shard. Defaults to False.
        tensor_parallel_size (int, optional): The tensor parallel size. If not specified, it is read from the model_config for tensor parallel shards and defaults to 1 otherwise.

    Returns:
        Dict[str, torch.Tensor]: The converted state dictionary in the target backend format.

    Raises:
        ValueError: If the conversion between the identified backend and target backend is not supported.
    """
    state_dict = process_state_dict(state_dict)
    if source_backend is None:
        source_backend = identify_checkpoint_backend(state_dict)
    if source_backend == target_backend:
        return state_dict
    else:
        if tensor_parallel_size is None:
            tensor_parallel_size = model_config["tensor_parallel_size"] if is_tensor_parallel_shard else 1

        if source_backend == "pytorch" and target_backend == "transformer_engine":
            return pytorch_to_te_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        elif source_backend == "transformer_engine" and target_backend == "pytorch":
            return te_to_pytorch_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        elif source_backend == "pytorch" and target_backend == "huggingface":
            return pytorch_to_hf_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        else:
            raise ValueError(f"Conversion from {source_backend} to {target_backend} is not supported.")

|
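if __name__ == "__main__":
    # Lightweight, illustrative self-check of the conversion helpers. The tensor shapes and
    # config values below are made up for this sketch; they only need to be self-consistent.

    # permute_weight and inverse_permute_weight are exact inverses of each other.
    n_heads, head_dim, dim = 2, 4, 16
    w = torch.randn(n_heads * head_dim, dim)
    w_roundtrip = inverse_permute_weight(
        permute_weight(w, n_heads, n_heads * head_dim, dim), n_heads, n_heads * head_dim, dim
    )
    assert torch.equal(w_roundtrip, w)

    # PyTorch -> TransformerEngine -> PyTorch round trip on a tiny single-layer state dict.
    cfg = {
        "dim": 8,
        "n_heads": 2,
        "n_kv_heads": 2,
        "head_dim": None,
        "context_dim": None,
        "n_layers": 1,
        "pytorch_rope_version": "v2",  # anything other than "v1" skips the RoPE weight permutation
    }
    pt_state_dict = {
        "layers.0.attention_norm.weight": torch.randn(8),
        "layers.0.attention.wq.weight": torch.randn(8, 8),
        "layers.0.attention.wk.weight": torch.randn(8, 8),
        "layers.0.attention.wv.weight": torch.randn(8, 8),
        "layers.0.attention.wo.weight": torch.randn(8, 8),
        "layers.0.ffn_norm.weight": torch.randn(8),
        "layers.0.feed_forward.w1.weight": torch.randn(16, 8),
        "layers.0.feed_forward.w2.weight": torch.randn(8, 16),
        "layers.0.feed_forward.w3.weight": torch.randn(16, 8),
    }
    te_state_dict = pytorch_to_te_state_dict(pt_state_dict, cfg)
    assert identify_checkpoint_backend(te_state_dict) == "transformer_engine"
    recovered = te_to_pytorch_state_dict(te_state_dict, cfg)
    assert set(recovered.keys()) == set(pt_state_dict.keys())
    assert all(torch.equal(recovered[k], pt_state_dict[k]) for k in pt_state_dict)
    print("checkpoint conversion self-check passed")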