File size: 3,623 Bytes
da806fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import copy
import json
import os

import mmengine
from mmengine.config import Config, ConfigDict

from opencompass.utils import build_dataset_from_cfg, get_infer_output_path


def parse_args(argv=None):
    """Parse command-line options for the prediction-merge script.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as with
            ``argparse.ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with ``config``, ``work_dir``, ``reuse``,
        ``clean`` and ``force`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Merge partitioned predictions')
    parser.add_argument('config', help='Config file path')
    parser.add_argument('-w', '--work-dir', default=None, type=str,
                        help='Work directory; overrides work_dir in the config')
    parser.add_argument('-r', '--reuse', default='latest', type=str,
                        help="Run sub-directory of the work dir to merge; "
                             "'latest' picks the last one in sorted order")
    parser.add_argument('-c', '--clean', action='store_true',
                        help='Remove partial prediction files after merging')
    parser.add_argument('-f', '--force', action='store_true',
                        help='Merge even if the merged file already exists')
    return parser.parse_args(argv)


class PredictionMerger:
    """Merge partitioned prediction files for one (model, dataset) pair.

    The final prediction path comes from ``get_infer_output_path``. Partial
    files are looked up next to it as ``<root>_0<ext>``, ``<root>_1<ext>``,
    ... (stopping at the first missing index). Their entries are
    concatenated in index order into one dict whose keys are renumbered
    ``'0'``, ``'1'``, ...

    Args:
        cfg: Mapping with keys ``model``, ``dataset``, ``work_dir``,
            ``clean`` and ``force``.
    """

    def __init__(self, cfg: ConfigDict) -> None:
        self.cfg = cfg
        # Deep copies, presumably so callees cannot mutate the caller's
        # shared model/dataset configs.
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
        self.work_dir = self.cfg.get('work_dir')

    @staticmethod
    def _partial_filenames(root, ext):
        """Return real paths of consecutive partials ``root_0ext``, ``root_1ext``, ...

        Scanning stops at the first index whose file does not exist, so the
        result is empty when ``root_0ext`` is missing.
        """
        paths = []
        index = 0
        while True:
            candidate = os.path.realpath(f'{root}_{index}{ext}')
            if not os.path.exists(candidate):
                return paths
            paths.append(candidate)
            index += 1

    @staticmethod
    def _load_predictions(paths):
        """Concatenate partial prediction dicts, renumbering keys from '0'.

        Each partial is expected to be keyed ``'0'`` .. ``str(len - 1)``;
        a missing key raises ``KeyError``.
        """
        preds, offset = {}, 0
        for path in paths:
            part = mmengine.load(path)
            for local_idx in range(len(part)):
                preds[str(offset)] = part[str(local_idx)]
                offset += 1
        return preds

    def run(self):
        """Merge the partial predictions into the final prediction file.

        Returns early without writing when the merged file already exists
        and ``force`` is not set, when no ``_0`` partial file exists, or
        when the number of merged predictions differs from
        ``len(dataset.test)``. When ``clean`` is set, the partial files are
        deleted after a successful merge.
        """
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            os.path.join(self.work_dir, 'predictions'))
        root, ext = os.path.splitext(filename)

        if os.path.exists(filename) and not self.cfg['force']:
            return

        partial_filenames = self._partial_filenames(root, ext)
        if not partial_filenames:
            # Report the file that is actually missing: with ``force`` set,
            # the merged file itself may well exist.
            print(f'{root}_0{ext} not found')
            return

        preds = self._load_predictions(partial_filenames)

        # The dataset is built only to check that no sample was dropped.
        dataset = build_dataset_from_cfg(self.dataset_cfg)
        expected = len(dataset.test)
        if len(preds) != expected:
            print(f'length mismatch: merged {len(preds)} predictions, '
                  f'but the dataset has {expected} test samples')
            return

        print(f'Merge {partial_filenames} to {filename}')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(preds, f, indent=4, ensure_ascii=False)

        if self.cfg['clean']:
            for partial_filename in partial_filenames:
                print(f'Remove {partial_filename}')
                os.remove(partial_filename)


def dispatch_tasks(cfg):
    """Run one ``PredictionMerger`` per (model, dataset) combination.

    Models form the outer loop and datasets the inner loop. Every merger
    receives the same ``work_dir``, ``clean`` and ``force`` settings from
    *cfg*.
    """
    for model_cfg in cfg['models']:
        for dataset_cfg in cfg['datasets']:
            merger = PredictionMerger(dict(
                model=model_cfg,
                dataset=dataset_cfg,
                work_dir=cfg['work_dir'],
                clean=cfg['clean'],
                force=cfg['force'],
            ))
            merger.run()


def _list_subdirs(path):
    """Return the names of the immediate sub-directories of *path*.

    Returns an empty list when *path* does not exist or is not a directory.
    Plain files inside *path* are ignored.
    """
    if not os.path.isdir(path):
        return []
    return [name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))]


def main():
    """Resolve the run directory from the CLI, then merge every model/dataset pair.

    ``--reuse latest`` selects the lexicographically last sub-directory of
    the work dir (run directories are presumably timestamp-named, so this is
    the most recent). Any other non-empty value is used as the sub-directory
    name directly.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set work_dir
    if args.work_dir is not None:
        cfg['work_dir'] = args.work_dir
    else:
        cfg.setdefault('work_dir', './outputs/default')

    # Without a run sub-directory there is nothing to join onto work_dir
    # below, so refuse an empty value explicitly instead of crashing.
    if not args.reuse:
        print('Please specify a run directory to reuse with -r/--reuse!')
        return

    if args.reuse == 'latest':
        # Only directories count as runs; a stray file must not be picked.
        run_dirs = _list_subdirs(cfg.work_dir)
        if not run_dirs:
            print('No previous results to reuse!')
            return
        dir_time_str = max(run_dirs)
    else:
        dir_time_str = args.reuse
    cfg['work_dir'] = os.path.join(cfg.work_dir, dir_time_str)

    cfg['clean'] = args.clean
    cfg['force'] = args.force

    dispatch_tasks(cfg)


if __name__ == '__main__':
    main()