deepspeed / scripts /tools /get_sub_model_name_from_ckpt.py
xingzhikb's picture
init
002bd9b
import sys
import json
import os
import os.path as osp
def get_sub_model_name(ckpt_path):
ckpt_json_path = osp.join(ckpt_path, "config.json")
with open(ckpt_json_path, "r") as f:
ckpt_json = json.load(f)
return ckpt_json
def parse_sub_model(ckpt_json, sub_model_type):
if sub_model_type not in ["sam", "lm"]:
raise ValueError("sub_model_type must be one of [sam, lm], but got {}".format(sub_model_type))
if sub_model_type == "sam":
return ckpt_json["_name_or_path"]
elif sub_model_type == "lm":
return ckpt_json["text_config"]["_name_or_path"]
if __name__ == "__main__":
ckpt_path = sys.argv[1]
sub_model_type = sys.argv[2]
ckpt_json = get_sub_model_name(ckpt_path)
sub_model_name = parse_sub_model(ckpt_json, sub_model_type)
print(sub_model_name)