# llava_finetune/utils.py
# Author: lyclyc52
# Update: integrate llama3 into finetuning code (commit 157f5b2)
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor
def rephrase_data_json(json_path):
    """Clean a conversation-data JSON file and save a '_rephrase' copy.

    For each entry: if the 'image' field is a list, keep only its first
    element; entries whose question or answer value is the empty string
    are dropped.  The filtered list is written next to the input file
    with the suffix '_rephrase.json'.
    """
    with open(json_path) as src:
        entries = json.load(src)
    kept = []
    for entry in entries:
        if isinstance(entry['image'], list):
            entry['image'] = entry['image'][0]
        conv = entry['conversations']
        if conv[0]['value'] == '' or conv[1]['value'] == '':
            continue
        kept.append(entry)
    with open(json_path.replace('.json', '_rephrase.json'), 'w') as dst:
        json.dump(kept, dst, indent=4)
    return
def test_llama3():
    """Smoke-test loading the Llama-3 LLM and the CLIP image processor.

    NOTE(review): the original bound both objects to the same ``model``
    variable, so the just-loaded language model was discarded as soon as
    the image processor was loaded; they are now kept in separate names
    and returned so callers can inspect them.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14-336",
        low_cpu_mem_usage=True,
    )
    return model, image_processor
def create_test_data():
    """Merge the first 200 samples of every '*_test.json' file under
    ./data_json into one combined test file."""
    merged = []
    for fname in os.listdir('./data_json'):
        if not fname.endswith('_test.json'):
            continue
        with open(os.path.join('./data_json', fname)) as src:
            merged.extend(json.load(src)[:200])
    with open('./data_json/instruct_sample_18430_0713_rephrase_test.json', 'w') as dst:
        json.dump(merged, dst, indent = 4)
def create_eval_dataset(gt_data_dir, pred_data_json, output_dir):
    """Join predictions with their ground-truth questions and write one
    '<task>_pred.json' file per task into ``output_dir``.

    Args:
        gt_data_dir: directory holding the per-task ground-truth files
            named '4dor_<task>_instruct_0711_test.json'.
        pred_data_json: path to the combined prediction JSON (a list of
            dicts with 'id', 'task', 'source', 'pred_output',
            'gt_output' keys).
        output_dir: directory that receives the per-task output files.

    Bug fixes vs. the original:
      * ``json.load`` was called on the path string instead of the open
        file object (TypeError at runtime).
      * ``str.start_with`` does not exist; the method is ``startswith``.
      * each task's output file is now written once, after the loop,
        instead of being rewritten for every prediction item.
    """
    with open(pred_data_json, 'r') as f:
        pred_data = json.load(f)
    gt_data_list = {}
    pred_data_list = {}
    for item in pred_data:
        task = item['task']
        pred_data_list.setdefault(task, [])
        cur_data = item
        cur_data['answer'] = cur_data['pred_output']
        cur_data.pop('source')
        cur_data.pop('task')
        cur_data.pop('gt_output')
        # Load each task's ground truth once; [:200] mirrors the test-set
        # sampling used elsewhere in this file.
        if task not in gt_data_list:
            gt_path = os.path.join(gt_data_dir, f'4dor_{task}_instruct_0711_test.json')
            with open(gt_path, 'r') as f:
                gt_data_list[task] = json.load(f)[:200]
        for gt in gt_data_list[task]:
            if gt['id'] == cur_data['id']:
                cur_data['question'] = gt['conversations'][0]['value']
                if cur_data['question'].startswith('<image>'):
                    # Drop the '<image>' tag plus the following character
                    # (presumably a newline) -- 8 chars, as in the original.
                    cur_data['question'] = cur_data['question'][8:]
        pred_data_list[task].append(cur_data)
    for task, preds in pred_data_list.items():
        output_file = os.path.join(output_dir, f'{task}_pred.json')
        with open(output_file, 'w') as f:
            json.dump(preds, f, indent=4)
def create_eval_dataset_from_output(pred_data_json, output_dir):
    """Split a combined prediction file into per-task eval files.

    Args:
        pred_data_json: path to the combined prediction JSON (a list of
            dicts with 'task', 'source', 'pred_output', 'gt_output' keys).
        output_dir: directory that receives '<task>_pred.json' files.

    Bug fix vs. the original: ``task_list`` was built but never written
    and ``output_dir`` was ignored, making the function a no-op.  Each
    task's predictions are now dumped to '<task>_pred.json' under
    ``output_dir``, matching ``create_eval_dataset``.
    """
    with open(pred_data_json) as f:
        pred_data = json.load(f)
    task_list = {}
    for item in pred_data:
        task = item['task']
        task_list.setdefault(task, [])
        cur_data = item
        cur_data['answer'] = cur_data['pred_output']
        cur_data.pop('source')
        cur_data.pop('task')
        cur_data.pop('gt_output')
        task_list[task].append(cur_data)
    for task, preds in task_list.items():
        with open(os.path.join(output_dir, f'{task}_pred.json'), 'w') as f:
            json.dump(preds, f, indent=4)
def debug():
    """Merge the per-rank temp prediction shards ./temp_0.json .. ./temp_7.json
    into ./preds.json plus one ./preds_<task>.json per task, deleting
    each shard after it is read.

    Bug fix vs. the original: ``os.path.join('.' f'preds_{k}.json')``
    was missing a comma, so the two adjacent string literals were
    concatenated and the per-task files were written as hidden
    '.preds_<task>.json' files instead of './preds_<task>.json'.
    """
    output_tasks = {}
    output_list = []
    for j in range(8):  # 8 shards: presumably one per GPU/rank -- TODO confirm
        shard_path = f'./temp_{j}.json'
        with open(shard_path, 'r') as f:
            shard = json.load(f)
        for t in shard:
            output_tasks.setdefault(t['task'], []).append(t)
            output_list.append(t)
        os.remove(shard_path)
    with open(os.path.join('.', 'preds.json'), 'w') as f:
        json.dump(output_list, f, indent=4)
    for task, items in output_tasks.items():
        with open(os.path.join('.', f'preds_{task}.json'), 'w') as f:
            json.dump(items, f, indent=4)
def count_correct(json_file = '/mnt1/lyc/llava_finetune/eval_output/pwiseg_count_eval_llama3_llava.json', total = 200):
    """Count 'yes' answers in an eval-output file and print the accuracy.

    Args:
        json_file: path to a JSON list of result dicts; plain-string
            entries (presumably error placeholders) are skipped.
        total: denominator for the printed rate.  Defaults to 200 — the
            hard-coded value in the original, matching the [:200]
            test-set sampling used elsewhere in this file — but is now
            a parameter so other sample sizes can be evaluated.

    Returns:
        The number of correct ('yes') answers (the original returned
        None; callers that ignore the return value are unaffected).
    """
    with open(json_file) as f:
        data = json.load(f)
    num_corr = 0
    for d in data:
        if isinstance(d, str):
            continue
        if d['answer'] == 'yes':
            num_corr += 1
    print(f'Rate: {num_corr / total}')
    print(f'Num: {num_corr}')
    return num_corr
def process_relationship_data_v0(relationship_json_file = '/mnt1/wjl/LLaVA/data/4dor_infos_0702.json'):
    """Load the 4D-OR relationship annotation file (v0 stub).

    Currently this only parses the JSON file; no processing is
    implemented yet and nothing is returned.
    """
    with open(relationship_json_file) as src:
        relationship_json_data = json.load(src)
# Script entry point: the commented-out calls below are the other one-off
# utilities in this file; currently only count_correct() is run, using its
# default hard-coded eval-output path.
if __name__ == '__main__':
    # json_path = '/mnt1/lyc/LLaVA-NeXT/instruct_sample_18430_0713.json'
    # rephrase_data_json(json_path)
    # test_llama3()
    # create_test_data()
    # debug()
    count_correct()