import json
import os
import time
import argparse
from openai_utils import ask_gpt_on_figure, ask_gpt
from llava_utils import ask_llm, ask_llm_on_figure, restart_model
from tqdm import tqdm
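
# Example invocation (script name and paths are illustrative, not from the repo):
#   python build_preference_data.py \
#       --source-data-path outputs/test_data.json \
#       --figure-path figures/ \
#       --save-path preference_pairs.json \
#       --num-samples 8 --gpu 0
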
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-data-path", type=str, required=True)
    parser.add_argument("--figure-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--num-samples", type=int, required=True)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--score-only", action="store_true", default=False)
    parser.add_argument("--gpt", action="store_true", default=False)
    args = parser.parse_args()

    source_path = args.source_data_path
    folder_path = args.figure_path
    save_path = args.save_path
    num_samples = args.num_samples
    device = f'cuda:{args.gpu}'

    # Pick the judge: the GPT API or a local LLaVA model.
    if args.gpt:
        func1, func2 = ask_gpt_on_figure, ask_gpt
        model = None
        processor = None
    else:
        func1, func2 = ask_llm_on_figure, ask_llm
        model, processor = restart_model(device)

    with open(source_path, 'r') as f:
        test_data = json.load(f)
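
    # Each source record is expected to carry at least an 'index', the original
    # 'description' and 'prompt', and the model's 'output'; the stages below add
    # 'figure_path', 'gpt_label', and 'gpt_score'.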

    ####### Stage 1 #######
    # For model generations that are able to render pictures, ask the judge
    # (GPT or the local LLaVA model) to rate the generation quality.
    for data in tqdm(test_data):
        file_id = str(data['index']).zfill(6)
        file = None
        for fname in os.listdir(folder_path):
            if fname.startswith(file_id):
                file = os.path.join(folder_path, fname)
        data['figure_path'] = file
        error_cnt = 0
        while True:
            try:
                data['gpt_label'] = func1(data, model, processor)
                break
            except Exception as e:
                print(e)
                if args.gpt:
                    # Assume a transient API error: back off and retry.
                    time.sleep(3)
                else:
                    # Restart the local model, giving up after 5 consecutive failures.
                    if error_cnt == 5:
                        exit()
                    model, processor = restart_model(device)
                    error_cnt += 1

    # Checkpoint the Stage 1 labels, then reload them for scoring.
    with open(save_path, 'w+') as f:
        json.dump(test_data, f, indent=4)
    with open(save_path, 'r') as f:
        test_data = json.load(f)

    ####### Stage 2 #######
    # Clean up the dataset: summarize the quality estimation into a numerical
    # score, and drop the failed samples, i.e. the generations that could not render.
    for data in tqdm(test_data):
        if "gpt_label" in data.keys():
            error_cnt = 0
            while True:
                try:
                    score = func2(data, model, processor)
                    print(score)
                    break
                except Exception as e:
                    print(e)
                    if args.gpt:
                        time.sleep(3)
                    else:
                        if error_cnt == 5:
                            exit()
                        model, processor = restart_model(device)
                        error_cnt += 1
            try:
                data['gpt_score'] = int(score)
            except (ValueError, TypeError):
                # The judge did not return a parseable integer score.
                print(f'ERROR: {score}')

    saved_data = [data for data in test_data if 'gpt_score' in data.keys()]
    with open(save_path, 'w+') as f:
        json.dump(saved_data, f, indent=4)
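
    # Note: save_path is reused as the output file of every stage, so the scored
    # records written here are overwritten by the preference pairs below unless
    # --score-only is passed.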

    if args.score_only:
        exit()

    ####### Stage 3 #######
    # 1. Group the scored generations by their description: we do not compare
    #    generation results that come from different original prompts.
    temp_data = []
    max_idx = test_data[-1]['index']
    # Number of description groups (each description owns num_samples generations).
    sample_size = max_idx // num_samples + 1

    # a. select a group if any of its samples scores >= 6
    # for i in range(sample_size):
    #     next_sample = test_data[i*num_samples:(i+1)*num_samples]
    #     next_sample = [item for item in next_sample if 'gpt_score' in item.keys()]
    #     above_score = [item['gpt_score'] >= 6 for item in next_sample]
    #     if any(above_score):
    #         temp_data.extend(next_sample)
    # temp_data = [data for data in temp_data if 'gpt_score' in data.keys()]

    # b. select a group if its average score is >= 6
    # for i in range(sample_size):
    #     next_sample = test_data[i*num_samples:(i+1)*num_samples]
    #     next_sample = [item for item in next_sample if 'gpt_score' in item.keys()]
    #     if len(next_sample) == 0:
    #         continue
    #     scores = sum(item['gpt_score'] for item in next_sample) / len(next_sample)
    #     if scores >= 6:
    #         temp_data.extend(next_sample)
    # temp_data = [data for data in temp_data if 'gpt_score' in data.keys()]

    # c. select every individual sample scoring >= 6 (the strategy currently in use)
    test_data = saved_data
    for item in test_data:
        if 'gpt_score' not in item.keys():
            continue
        if item['gpt_score'] >= 6:
            temp_data.append(item)

    # Sanity check: the last kept index should not exceed the original maximum.
    print(test_data[-1]['index'], max_idx)
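
    # Generations are indexed so that each description owns num_samples
    # consecutive indices (e.g. with num_samples = 4, indices 0-3 all come from
    # the first description), so index // num_samples identifies the group.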
    grouped = [[] for _ in range(sample_size)]  # one (possibly empty) bucket per description
    for item in temp_data:
        idx = item['index']
        grouped[idx // num_samples].append(item)
    grouped = [item for item in grouped if len(item) > 0]

    # 2. Within each group, build preference pairs in which the chosen output
    #    has a strictly higher score than the rejected one.
    # TODO: find a way to balance the data generated from each group
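
    # The pairs below follow the usual chosen/rejected preference format (e.g.
    # for DPO-style training): both outputs answer the same prompt, and
    # "chosen" is the one the judge scored strictly higher.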
    final_data = []
    for group in grouped:
        for item1 in group:
            for item2 in group:
                if item2['gpt_score'] > item1['gpt_score']:
                    info_dict = {
                        "description": item1['description'],
                        "prompt": item1['prompt'],
                        "chosen": item2['output'],
                        "rejected": item1['output']
                    }
                    final_data.append(info_dict)
                    # uncomment this break to keep at most one pair per rejected item
                    # break

    with open(save_path, 'w+') as f:
        json.dump(final_data, f, indent=4)