File size: 833 Bytes
002bd9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys
import json
import os
import os.path as osp


def get_sub_model_name(ckpt_path):
    """Load and return the parsed ``config.json`` found in ``ckpt_path``.

    Despite its name, this returns the whole parsed JSON document (whatever
    ``json.load`` produces), not a single model name; use
    ``parse_sub_model`` on the result to extract a name.

    Args:
        ckpt_path: Directory containing a ``config.json`` file.

    Returns:
        The deserialized contents of ``<ckpt_path>/config.json``.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        json.JSONDecodeError: if the file is not valid JSON.
    """
    ckpt_json_path = osp.join(ckpt_path, "config.json")

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(ckpt_json_path, "r", encoding="utf-8") as f:
        ckpt_json = json.load(f)

    return ckpt_json


def parse_sub_model(ckpt_json, sub_model_type):
    """Return the ``_name_or_path`` entry selected by ``sub_model_type``.

    Args:
        ckpt_json: Parsed checkpoint config (presumably the dict returned by
            loading ``config.json``).
        sub_model_type: ``"sam"`` reads the top-level ``_name_or_path``;
            ``"lm"`` reads ``_name_or_path`` inside ``text_config``.

    Returns:
        The value stored under the selected ``_name_or_path`` key.

    Raises:
        ValueError: if ``sub_model_type`` is neither ``"sam"`` nor ``"lm"``.
        KeyError: if ``ckpt_json`` is a dict lacking the selected key.
    """
    if sub_model_type == "sam":
        return ckpt_json["_name_or_path"]

    if sub_model_type == "lm":
        text_config = ckpt_json["text_config"]
        return text_config["_name_or_path"]

    # Anything else is rejected before touching ckpt_json.
    raise ValueError(
        "sub_model_type must be one of [sam, lm], but got {}".format(sub_model_type)
    )


if __name__ == "__main__":
    # Usage: <script> CKPT_DIR {sam,lm}
    # Prints the _name_or_path recorded for the requested sub-model.
    # Fail with a usage message instead of an IndexError traceback when
    # arguments are missing; extra trailing arguments are ignored as before.
    if len(sys.argv) < 3:
        sys.exit("usage: {} CKPT_DIR {{sam,lm}}".format(sys.argv[0]))

    ckpt_path = sys.argv[1]
    sub_model_type = sys.argv[2]
    ckpt_json = get_sub_model_name(ckpt_path)
    try:
        sub_model_name = parse_sub_model(ckpt_json, sub_model_type)
    except ValueError as err:
        # Invalid sub_model_type: report the message on stderr, exit status 1.
        sys.exit(str(err))
    print(sub_model_name)