File size: 2,518 Bytes
8520896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import json
import argparse
import nltk


def save_json_pretty(data, file_path):
    """Write ``data`` to ``file_path`` as human-readable JSON.

    The output is indented by 4 spaces and dictionary keys are sorted, which
    makes it suitable for config-style files that people read or diff.
    """
    with open(file_path, "w") as out_file:
        serialized = json.dumps(data, indent=4, sort_keys=True)
        out_file.write(serialized)


def load_json(file_path):
    """Read the JSON document stored at ``file_path`` and return the parsed object."""
    with open(file_path, "r") as json_file:
        parsed = json.load(json_file)
    return parsed


def flat_list_of_lists(l):
    """Concatenate the inner sequences of ``l`` into one new list.

    Example: [[1, 2], [3, 4]] -> [1, 2, 3, 4]
    """
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat


def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``submission``, ``reference`` and ``output``
        (file paths, all required) and ``verbose`` (bool).

    Raises:
        SystemExit: raised by argparse when a required option is missing or an
            unknown option is given.
    """
    parser = argparse.ArgumentParser()
    # The three paths are mandatory: eval_cap() opens every one of them, and
    # without required=True a missing flag only surfaced later as an opaque
    # TypeError from open(None) instead of a usage message.
    parser.add_argument("-s", "--submission", type=str, required=True,
                        help="submission file")
    parser.add_argument("-v", "--verbose", action="store_true", help="print info")
    parser.add_argument("-r", "--reference", type=str, required=True,
                        help="GT reference, used to collect the video ids")
    parser.add_argument("-o", "--output", type=str, required=True, help="save path")
    return parser.parse_args(argv)


def get_sen_stat(list_of_str, tokenizer=None):
    """Compute simple corpus statistics over a collection of texts.

    Every text (a sentence or a whole paragraph) is lower-cased, then tokenized.

    Args:
        list_of_str: iterable of str; each item may be a sentence or a paragraph.
        tokenizer: optional callable mapping a str to a list of tokens.
            Defaults to ``nltk.tokenize.word_tokenize``.

    Returns:
        dict with
          - "vocab_size": number of distinct (lower-cased) tokens,
          - "avg_sen_len": mean number of tokens per text (0.0 for empty input),
          - "num_sen": number of texts.
    """
    if tokenizer is None:
        tokenizer = nltk.tokenize.word_tokenize
    tokenized = [tokenizer(sen.lower()) for sen in list_of_str]
    num_sen = len(tokenized)
    total_tokens = sum(len(tokens) for tokens in tokenized)
    # Guard the division: an empty input previously raised ZeroDivisionError.
    avg_len = float(total_tokens) / num_sen if num_sen else 0.0
    full_vocab = {tok for tokens in tokenized for tok in tokens}
    return {"vocab_size": len(full_vocab), "avg_sen_len": avg_len, "num_sen": num_sen}


def eval_cap():
    """Get vocab size, average sentence length and sentence count of a submission.

    Command-line options come from ``get_args()``:
      - ``--submission`` and ``--reference`` are JSON files;
      - ``--output`` is where the statistics are written;
      - ``--verbose`` prints the submission statistics.

    Either file may wrap its payload in a top-level ``"results"`` key. Only
    submission keys (video ids) that also appear in the reference are kept.
    Each kept value is a sequence of entry dicts that must contain
    ``"sentence"``; entries may also carry ``"gt_sentence"``.

    Returns:
        dict: ``{"submission": stats}``, plus ``"gt_stat": stats`` when the
        first entry carries a ground-truth sentence. The same dict is saved
        to ``args.output``.
    """
    args = get_args()

    # load data; load_json uses a context manager, so the file handles are
    # closed (json.load(open(...)) left them open).
    sub_data = load_json(args.submission)
    ref_data = load_json(args.reference)
    sub_data = sub_data["results"] if "results" in sub_data else sub_data
    ref_data = ref_data["results"] if "results" in ref_data else ref_data
    # evaluate only the video ids that are present in the GT reference
    sub_data = {k: v for k, v in sub_data.items() if k in ref_data}

    submission_data_entries = flat_list_of_lists(sub_data.values())
    submission_sentences = [e["sentence"] for e in submission_data_entries]
    submission_stat = get_sen_stat(submission_sentences)

    if args.verbose:
        for k, v in submission_stat.items():
            print("{} submission {}".format(k, v))
    final_res = {"submission": submission_stat}

    # Check for emptiness before indexing [0]: a submission with no entries
    # (or no overlap with the reference) would otherwise raise IndexError.
    if submission_data_entries and "gt_sentence" in submission_data_entries[0]:
        gt_sentences = [e["gt_sentence"] for e in submission_data_entries]
        # only one reference is used here: the per-entry "gt_sentence"
        gt_stat = get_sen_stat(gt_sentences)
        final_res["gt_stat"] = gt_stat

    save_json_pretty(final_res, args.output)
    return final_res

if __name__ == "__main__":
    # Script entry point: parse CLI args, compute the stats and save them.
    eval_cap()