File size: 3,380 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import tqdm
from libs.utils.teds import TEDS
from collections import defaultdict


# def parse_args():
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('pred_path', type=str, default=None)
#     parser.add_argument('label_path', type=str, default=None)
#     parser.add_argument('-s', '--structure_only', action='store_true')
#     parser.add_argument('-n', '--num_workers', type=int, default=1)
#     args = parser.parse_args()
#     return args


def is_simple(data):
    """Return True when *data* mentions neither ``colspan`` nor ``rowspan``.

    The check is a plain substring/membership test (``in``) on *data*; no
    HTML parsing is done.
    """
    return all(span_attr not in data for span_attr in ('colspan', 'rowspan'))


def judge_type(data):
    """Classify *data* as ``'Simple'`` or ``'Complex'`` using ``is_simple``."""
    return 'Simple' if is_simple(data) else 'Complex'


def single_process(pred_htmls, label_htmls, structure_only=False):
    """Score every labelled key in the current process.

    Args:
        pred_htmls: mapping key -> predicted HTML string; a missing key is
            scored against the empty string ``''``.
        label_htmls: mapping key -> record whose ``'html'`` entry is the
            ground-truth HTML.
        structure_only: forwarded to ``TEDS(structure_only=...)``.

    Returns:
        dict mapping each key of *label_htmls* to its TEDS score.
    """
    scorer = TEDS(structure_only=structure_only)
    return {
        sample_id: scorer.evaluate(
            pred_htmls.get(sample_id, ''),
            label_htmls[sample_id]['html'],
        )
        for sample_id in tqdm.tqdm(label_htmls.keys())
    }


def _worker(pred_htmls, label_htmls, keys, result_queue, structure_only=False):
    """Score the given subset of *keys* and push ``(key, score)`` tuples.

    Intended as a ``multiprocessing.Process`` target: each result is sent via
    ``result_queue.put`` rather than returned.
    """
    teds_scorer = TEDS(structure_only=structure_only)
    for sample_id in keys:
        predicted = pred_htmls.get(sample_id, '')  # missing prediction -> ''
        reference = label_htmls[sample_id]['html']
        result_queue.put((sample_id, teds_scorer.evaluate(predicted, reference)))


def multi_process(pred_htmls, label_htmls, num_workers, structure_only=False):
    """Score every labelled key using ``num_workers`` worker processes.

    Keys are distributed round-robin (``keys[i::num_workers]``) across
    ``_worker`` processes, which send ``(key, score)`` tuples back through a
    ``multiprocessing.Manager`` queue. A progress bar shows the running mean.

    Args:
        pred_htmls: mapping key -> predicted HTML string.
        label_htmls: mapping key -> record whose ``'html'`` entry is the
            ground-truth HTML.
        num_workers: number of worker processes to start.
        structure_only: forwarded to every worker's ``TEDS`` evaluator.

    Returns:
        dict mapping each key of *label_htmls* to its TEDS score, in the order
        the results arrived.

    Note:
        ``result_queue.get()`` is called without a timeout, so this blocks
        indefinitely if a worker dies before reporting all of its keys.
    """
    import multiprocessing

    keys = list(label_htmls.keys())
    scores = dict()
    with multiprocessing.Manager() as manager:
        result_queue = manager.Queue()
        workers = list()
        for worker_idx in range(num_workers):
            worker = multiprocessing.Process(
                target=_worker,
                args=(
                    pred_htmls,
                    label_htmls,
                    keys[worker_idx::num_workers],
                    result_queue,
                    # Previously omitted, so workers silently ignored it.
                    structure_only,
                ),
            )
            worker.daemon = True
            worker.start()
            workers.append(worker)

        # Keep a running total instead of re-summing all scores per result.
        total = 0.0
        tq = tqdm.tqdm(total=len(keys))
        try:
            for _ in range(len(keys)):
                key, val = result_queue.get()
                scores[key] = val
                total += val
                teds = total / len(scores)
                tq.set_description('Teds: %s' % teds, False)
                tq.update()
        finally:
            tq.close()

        # Every key has been reported; reap the workers before the manager
        # (and its queue) is shut down by the ``with`` exit.
        for worker in workers:
            worker.join()
    return scores


def evaluate(pred_htmls, label_htmls, num_workers, structure_only=False):
    """Compute the overall TEDS and the mean TEDS per table type.

    Runs ``single_process`` when ``num_workers <= 1``, otherwise
    ``multi_process``. Each key is then bucketed by
    ``judge_type(label_htmls[key]['html'])``.

    Returns:
        ``(teds, typed_teds)``: the mean score over all keys, and a dict
        mapping each type label to the mean score of its keys.

    Raises:
        ZeroDivisionError: if *label_htmls* is empty (no scores to average).
    """
    if num_workers <= 1:
        scores = single_process(pred_htmls, label_htmls, structure_only)
    else:
        scores = multi_process(pred_htmls, label_htmls, num_workers, structure_only)
    overall = sum(scores.values()) / len(scores)

    buckets = defaultdict(list)
    for sample_id, value in scores.items():
        buckets[judge_type(label_htmls[sample_id]['html'])].append(value)

    per_type = dict()
    for type_name, bucket in buckets.items():
        per_type[type_name] = sum(bucket) / len(bucket)
    return overall, per_type


# def main():
#     args = parse_args()
#     pred_data = json.load(open(args.pred_path))
#     label_data = json.load(open(args.label_path))
    
#     teds, typed_teds = evaluate(pred_data, label_data, args.num_workers, args.structure_only)
#     print('Teds: %s' % teds)
#     for key, val in typed_teds.items():
#         print('    %s Teds: %s' % (key, val))


# if __name__ == '__main__':
#     main()