File size: 6,265 Bytes
91daf98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
import os
import time
import argparse

from openai_utils import ask_gpt_on_figure, ask_gpt
from llava_utils import ask_llm, ask_llm_on_figure, restart_model
from tqdm import tqdm


def _call_with_retry(func, data, model, processor, use_gpt, device):
    """Call ``func(data, model, processor)`` with error recovery.

    GPT backend: sleep 3s and retry indefinitely (transient API errors).
    Local backend: restart the model on failure, aborting the whole script
    after 5 consecutive errors (matches the original retry budget).

    Returns (result, model, processor); model/processor may have been
    replaced by a restart and must be carried forward by the caller.
    """
    error_cnt = 0
    while True:
        try:
            return func(data, model, processor), model, processor
        except Exception as e:
            print(e)
            if use_gpt:
                time.sleep(3)
            else:
                if error_cnt == 5:
                    exit()
                model, processor = restart_model(device)
        error_cnt += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-data-path", type=str, required=True)
    parser.add_argument("--figure-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--num-samples", type=int, required=True)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--score-only", action="store_true", default=False)
    parser.add_argument("--gpt", action="store_true", default=False)
    args = parser.parse_args()

    source_path = args.source_data_path
    folder_path = args.figure_path
    save_path = args.save_path
    num_samples = args.num_samples
    device = f'cuda:{args.gpu}'

    # Select the scoring backend: OpenAI API vs. a local LLaVA model.
    if args.gpt:
        func1, func2 = ask_gpt_on_figure, ask_gpt
        model, processor = None, None
    else:
        func1, func2 = ask_llm_on_figure, ask_llm
        model, processor = restart_model(device)

    with open(source_path, 'r') as src:
        test_data = json.load(src)

    ####### Stage 1 #######
    # For model generations that rendered a figure, ask the backend to rate
    # the generation quality. Figure files are matched by the zero-padded
    # sample index prefix of their filename.
    for data in tqdm(test_data):
        file_id = str(data['index']).zfill(6)
        for fname in os.listdir(folder_path):
            if fname.startswith(file_id):
                # FIX: os.path.join instead of raw concatenation, which
                # produced a broken path when folder_path lacked a trailing /.
                data['figure_path'] = os.path.join(folder_path, fname)
                data['gpt_label'], model, processor = _call_with_retry(
                    func1, data, model, processor, args.gpt, device)
                # FIX: stop after the first match; previously every matching
                # file triggered another expensive rating call and the last
                # one (in arbitrary os.listdir order) silently won.
                break
    # Checkpoint stage-1 labels before the (slow) scoring stage.
    with open(save_path, 'w+') as out:
        json.dump(test_data, out, indent=4)

    with open(save_path, 'r') as src:
        test_data = json.load(src)

    ####### Stage 2 #######
    # Summarize each quality estimate into a numerical score and drop the
    # failed generations (those that could not render, i.e. no gpt_label).
    for data in tqdm(test_data):
        if "gpt_label" in data:
            score, model, processor = _call_with_retry(
                func2, data, model, processor, args.gpt, device)
            print(score)
            try:
                data['gpt_score'] = int(score)
            except (TypeError, ValueError):
                # FIX: was a bare except; only conversion failures belong here.
                print(f'ERROR: {score}')

    saved_data = [data for data in test_data if 'gpt_score' in data]
    with open(save_path, 'w+') as out:
        json.dump(saved_data, out, indent=4)

    if args.score_only:
        exit()

    ####### Stage 3 #######
    # 1. Group the scored generations by origin prompt: generations from
    #    different prompts are never compared against each other.
    #    Indices are assigned so that samples i*num_samples..(i+1)*num_samples-1
    #    all come from prompt group i.
    max_idx = test_data[-1]['index']
    sample_size = max_idx // num_samples + 1  # number of prompt groups

    # Selection strategy (c): keep every generation scored >= 6 individually.
    # (Earlier strategies — keep a group if ANY member >= 6, or if the group
    # AVERAGE >= 6 — were tried and abandoned.)
    test_data = saved_data
    temp_data = [item for item in test_data
                 if 'gpt_score' in item and item['gpt_score'] >= 6]
    print(test_data[-1]['index'], max_idx)

    # FIX: the bucket list was sized range(max_idx), which both over-allocates
    # and raises IndexError when num_samples == 1 (idx // 1 can equal max_idx).
    grouped = [[] for _ in range(sample_size)]
    for item in temp_data:
        grouped[item['index'] // num_samples].append(item)
    grouped = [bucket for bucket in grouped if bucket]

    # 2. Within each group, emit every (chosen, rejected) pair where the
    #    chosen generation strictly out-scores the rejected one.
    # TODO: find a way to balance the data generated from each group
    final_data = []
    for group in grouped:
        for rejected in group:
            for chosen in group:
                if chosen['gpt_score'] > rejected['gpt_score']:
                    final_data.append({
                        "description": rejected['description'],
                        "prompt": rejected['prompt'],
                        "chosen": chosen['output'],
                        "rejected": rejected['output'],
                    })
                    # uncomment this break if you do not want too many data.
                    # break

    with open(save_path, 'w+') as out:
        json.dump(final_data, out, indent=4)