| """ |
| This script converts the weights from LLAMA to OpenLM compatible weights. |
| Usage: `python convert_llama_to_openlm.py <llama_weight_path> <openlm_weight_path>` |
| """ |
|
|
import sys

import torch
|
|
|
|
| def convert(llama_state_dict: dict) -> dict: |
| openlm_state_dict = {} |
|
|
    # Infer the layer count from keys of the form "layers.<i>.<...>".
    n_layer = len({key.split(".")[1] for key in llama_state_dict if key.startswith("layers.")})
    print(f"n_layer: {n_layer}")
|
|
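    # The token embedding, final norm, and output head carry over unchanged.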
| for key in ["tok_embeddings.weight", "norm.weight", "output.weight"]: |
| value = llama_state_dict[key] |
| assert key not in openlm_state_dict |
| openlm_state_dict[key] = value |
|
|
| for i in range(n_layer): |
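        # LLaMA stores separate q/k/v projections; OpenLM expects them fused
        # row-wise into a single in_proj (q, then k, then v).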
        qkv = [
            llama_state_dict[f"layers.{i}.attention.{name}.weight"]
            for name in ("wq", "wk", "wv")
        ]
        tgt_key = f"layers.{i}.attention.in_proj.weight"
        assert tgt_key not in openlm_state_dict
        openlm_state_dict[tgt_key] = torch.cat(qkv, dim=0)
|
|
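        # The attention output projection maps one-to-one; only the name changes.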
| src_key = f"layers.{i}.attention.wo.weight" |
| tgt_key = f"layers.{i}.attention.out_proj.weight" |
| assert tgt_key not in openlm_state_dict |
| openlm_state_dict[tgt_key] = llama_state_dict[src_key] |
|
|
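        # LLaMA's SwiGLU gate (w1) and up (w3) projections fuse row-wise into
        # OpenLM's w12, mirroring the q/k/v fusion above.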
        w1_w3 = [
            llama_state_dict[f"layers.{i}.feed_forward.{name}.weight"]
            for name in ("w1", "w3")
        ]
        tgt_key = f"layers.{i}.feed_forward.w12.weight"
        assert tgt_key not in openlm_state_dict
        openlm_state_dict[tgt_key] = torch.cat(w1_w3, dim=0)
|
|
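        # LLaMA's down projection (w2) becomes OpenLM's feed_forward.w3. The
        # name shift is intentional: OpenLM's fused SwiGLU names its weights
        # w12/w3, not w1/w2/w3.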
| src_key = f"layers.{i}.feed_forward.w2.weight" |
| tgt_key = f"layers.{i}.feed_forward.w3.weight" |
| assert tgt_key not in openlm_state_dict |
| openlm_state_dict[tgt_key] = llama_state_dict[src_key] |
|
|
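        # Both norm weights keep identical names in the two formats.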
| tgt_key = f"layers.{i}.attention_norm.weight" |
| assert tgt_key not in openlm_state_dict |
| openlm_state_dict[tgt_key] = llama_state_dict[tgt_key] |
|
|
| tgt_key = f"layers.{i}.ffn_norm.weight" |
| assert tgt_key not in openlm_state_dict |
| openlm_state_dict[tgt_key] = llama_state_dict[tgt_key] |
|
|
| return openlm_state_dict |
|
|
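def sanity_check(openlm_state_dict: dict) -> None:
    """Minimal shape check on the converted weights; a sketch, not part of the
    original conversion. Assumes the q, k, and v projections share one shape,
    which holds for LLaMA models without grouped-query attention."""
    n_layer = len({key.split(".")[1] for key in openlm_state_dict if key.startswith("layers.")})
    for i in range(n_layer):
        in_proj = openlm_state_dict[f"layers.{i}.attention.in_proj.weight"]
        out_proj = openlm_state_dict[f"layers.{i}.attention.out_proj.weight"]
        # in_proj stacks q, k, and v row-wise, so it is three times as tall
        # as out_proj is wide.
        assert in_proj.shape[0] == 3 * out_proj.shape[1], f"layer {i}: unexpected in_proj shape"
        w12 = openlm_state_dict[f"layers.{i}.feed_forward.w12.weight"]
        w3 = openlm_state_dict[f"layers.{i}.feed_forward.w3.weight"]
        # w12 stacks the gate and up projections, so it has twice as many
        # rows as w3 has columns.
        assert w12.shape[0] == 2 * w3.shape[1], f"layer {i}: unexpected w12 shape"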
|
|
| if __name__ == "__main__": |
| if len(sys.argv) != 3: |
| print("Usage: `python convert_llama_to_openlm.py <llama_weight_path> <openlm_weight_path>`") |
| sys.exit(1) |
    # map_location="cpu" keeps the conversion runnable without a GPU.
    llama_state_dict = torch.load(sys.argv[1], map_location="cpu")
| openlm_state_dict = {"state_dict": convert(llama_state_dict)} |
| torch.save(openlm_state_dict, sys.argv[2]) |
|
|