| | |
| | |
| |
|
import re
from pathlib import Path

import torch
# Load the source checkpoint. map_location="cpu" lets the conversion run on a
# machine without the device the tensors were originally saved from.
# NOTE: torch.load unpickles the file -- only run this on a trusted checkpoint.
full_state_dict = torch.load("./pytorch_model.bin", map_location="cpu")

# Drop the first dotted component of every parameter name (presumably the
# top-level module name -- confirm against the source checkpoint). A name with
# no "." becomes "". If two names differ only in that first component, the
# later one silently wins.
full_state_dict = {k.partition(".")[2]: v for k, v in full_state_dict.items()}
| |
|
def con_cat(kqv_dict, prefix="encoder."):
    """Fuse one layer's separate q/k/v projection tensors into a single tensor.

    ``kqv_dict`` holds the query, key and value parameters of a single
    attention layer, all of the same kind (all weights or all biases), keyed
    like ``layers.<i>.self_attn.<q|k|v>_proj.<weight|bias>``. The three tensors
    are concatenated with ``torch.cat`` (dim 0) in q, k, v order, regardless of
    the order in which they appear in ``kqv_dict``.

    Args:
        kqv_dict: mapping of parameter name -> tensor containing the
            ``q_proj``, ``k_proj`` and ``v_proj`` entries of one layer.
        prefix: string prepended to the fused key; defaults to ``"encoder."``.

    Returns:
        A one-entry dict
        ``{prefix + "layers.<i>.self_attn.in_proj_<weight|bias>": fused}``.

    Raises:
        ValueError: if ``kqv_dict`` is empty or its first key does not follow
            the expected naming scheme.
        KeyError: if the q, k or v entry for the layer is missing.
    """
    if not kqv_dict:
        raise ValueError("con_cat() needs the q/k/v projection tensors of a layer; got none")

    first_key = next(iter(kqv_dict))
    parts = first_key.split(".")
    kind = parts[-1]  # "weight" or "bias"
    if len(parts) < 5 or kind not in ("weight", "bias"):
        raise ValueError(
            f"unexpected attention parameter name {first_key!r}; expected "
            "'layers.<i>.self_attn.<q|k|v>_proj.<weight|bias>'"
        )

    # Build each sibling name by swapping ONLY the projection component
    # (index 3); a plain str.replace would also rewrite any other occurrence
    # of the projection name elsewhere in the key.
    head, tail = parts[:3], parts[4:]
    fused = torch.cat(
        [kqv_dict[".".join(head + [proj] + tail)] for proj in ("q_proj", "k_proj", "v_proj")]
    )
    fused_key = ".".join(head + [f"in_proj_{kind}"])
    return {f"{prefix}{fused_key}": fused}
| |
|
| |
|
# Start the converted state dict with the embedding-level parameters: any name
# mentioning "layer_norm" or "embedding" is re-keyed under "embeddings.".
# NOTE(review): this is a substring match, so a per-layer norm that is also
# named "layer_norm" would be routed here too -- confirm against the checkpoint.
mod_dict = {
    f"embeddings.{name}": tensor
    for name, tensor in full_state_dict.items()
    if "layer_norm" in name or "embedding" in name
}
| |
|
| | |
# Encoder layers. First bucket every "layers.<i>." parameter by its exact
# integer index. Matching the whole index component (rather than the substring
# f"layers.{i}") keeps "layers.1" from also sweeping up "layers.10"/"layers.11".
# Iterating over the indices actually present (rather than range(12)) means
# deeper checkpoints no longer silently lose their fused attention projections.
layer_pattern = re.compile(r"(?:^|\.)layers\.(\d+)\.")
params_by_layer = {}
for k, v in full_state_dict.items():
    match = layer_pattern.search(k)
    if match:
        params_by_layer.setdefault(int(match.group(1)), {})[k] = v

for i in sorted(params_by_layer):
    kvq_weight = {}
    kvq_bias = {}
    for k, v in params_by_layer[i].items():
        if "self_attn" in k and "out_proj" not in k:
            # q/k/v projections are collected here and fused by con_cat below.
            if "weight" in k:
                kvq_weight[k] = v
            if "bias" in k:
                kvq_bias[k] = v
        else:
            # Everything else in the layer (including out_proj) is copied
            # through under the "encoder." namespace.
            mod_dict[f"encoder.{k}"] = v

    mod_dict.update(con_cat(kvq_weight))
    mod_dict.update(con_cat(kvq_bias))
| |
|
| | |
# Pooler: re-key every parameter whose name contains "dense" under "pooler.".
pooler_params = {
    f"pooler.{name}": tensor for name, tensor in full_state_dict.items() if "dense" in name
}
mod_dict.update(pooler_params)
| |
|
| |
|
# Sanity check before saving: list every converted parameter with its shape.
for param_name in mod_dict:
    print(param_name, mod_dict[param_name].size())
| |
|
# Write the converted state dict to ./<model_name>/pytorch_model.bin.
# torch.save does not create missing directories, so create the output folder
# first (no-op if it already exists).
model_name = "ernie-m-base_pytorch"
PATH = f"./{model_name}/pytorch_model.bin"
Path(PATH).parent.mkdir(parents=True, exist_ok=True)
torch.save(mod_dict, PATH)