File size: 833 Bytes
002bd9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import sys
import json
import os
import os.path as osp
def get_sub_model_name(ckpt_path):
    """Load and return the parsed ``config.json`` of a checkpoint directory.

    Despite its name, this function does not return a model name: it returns
    the whole decoded JSON document, which callers then pass to
    ``parse_sub_model`` to pick out a specific name.

    Args:
        ckpt_path: Directory that contains a ``config.json`` file.

    Returns:
        The decoded JSON content of ``<ckpt_path>/config.json``.

    Raises:
        OSError: If the file is missing or cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    ckpt_json_path = osp.join(ckpt_path, "config.json")
    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(ckpt_json_path, "r", encoding="utf-8") as f:
        return json.load(f)
def parse_sub_model(ckpt_json, sub_model_type):
    """Extract the ``_name_or_path`` entry for one sub-model from a config.

    Args:
        ckpt_json: Parsed checkpoint config (a mapping).
        sub_model_type: ``"sam"`` reads the top-level ``_name_or_path``;
            ``"lm"`` reads ``text_config`` -> ``_name_or_path``.

    Returns:
        The value stored under the selected ``_name_or_path`` key.

    Raises:
        ValueError: If ``sub_model_type`` is neither ``"sam"`` nor ``"lm"``.
        KeyError: If the selected key is absent from ``ckpt_json``.
    """
    if sub_model_type == "sam":
        return ckpt_json["_name_or_path"]

    if sub_model_type == "lm":
        text_config = ckpt_json["text_config"]
        return text_config["_name_or_path"]

    # Anything else is an unsupported selector.
    message = "sub_model_type must be one of [sam, lm], but got {}".format(sub_model_type)
    raise ValueError(message)
if __name__ == "__main__":
    # Command-line entry point.
    # Usage: <script> <ckpt_path> <sub_model_type>
    # Prints the selected sub-model name to stdout.
    if len(sys.argv) < 3:
        # Exit with a readable usage line (written to stderr, non-zero status)
        # instead of an IndexError traceback when arguments are missing.
        sys.exit("usage: {} <ckpt_path> <sam|lm>".format(sys.argv[0]))

    ckpt_path, sub_model_type = sys.argv[1], sys.argv[2]
    ckpt_json = get_sub_model_name(ckpt_path)
    print(parse_sub_model(ckpt_json, sub_model_type))
|