| import os |
| import json |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor |
|
|
def rephrase_data_json(json_path):
    """Filter a conversation-dataset json and write a cleaned copy.

    Keeps only entries whose first two conversation turns are non-empty,
    collapses a list-valued 'image' field to its first element, and writes
    the result next to the input with a ``_rephrase`` suffix
    (``foo.json`` -> ``foo_rephrase.json``).

    Args:
        json_path: path to a json file containing a list of entry dicts,
            each with 'image' and 'conversations' keys.
    """
    with open(json_path) as f:
        entries = json.load(f)
    kept = []
    # Iterate entries directly instead of indexing via range(len(...)).
    for data_dict in entries:
        if isinstance(data_dict['image'], list):
            # Multi-image entries keep only the first image.
            data_dict['image'] = data_dict['image'][0]
        # Drop entries with an empty question or answer turn.
        if data_dict['conversations'][0]['value'] == '' or \
                data_dict['conversations'][1]['value'] == '':
            continue
        kept.append(data_dict)
    with open(json_path.replace('.json', '_rephrase.json'), 'w') as f:
        json.dump(kept, f, indent=4)
    return
def test_llama3():
    """Smoke-test loading the Llama-3 LLM and the CLIP image processor.

    BUG FIX: the original assigned both objects to the same ``model``
    variable, so the image processor silently clobbered the freshly
    loaded language model. Use distinct names and return both so a
    caller can inspect them.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14-336",
        low_cpu_mem_usage=True,
    )
    return model, image_processor
| |
def create_test_data():
    """Gather up to 200 entries from every '*_test.json' file under
    ./data_json and write them into one combined sample file."""
    collected = []
    for fname in os.listdir('./data_json'):
        if not fname.endswith('_test.json'):
            continue
        with open(os.path.join('./data_json', fname)) as fp:
            collected.extend(json.load(fp)[:200])
    with open('./data_json/instruct_sample_18430_0713_rephrase_test.json', 'w') as fp:
        json.dump(collected, fp, indent=4)
| |
def create_eval_dataset(gt_data_dir, pred_data_json, output_dir):
    """Split a combined prediction file into per-task eval json files.

    For every prediction entry, renames ``pred_output`` to ``answer``,
    drops bookkeeping keys, attaches the original question text from the
    matching ground-truth file (``4dor_{task}_instruct_0711_test.json``),
    and writes one ``{task}_pred.json`` per task into *output_dir*.

    Args:
        gt_data_dir: directory holding the per-task ground-truth jsons.
        pred_data_json: path to the combined predictions json (a list of
            dicts with 'id', 'task', 'source', 'gt_output', 'pred_output').
        output_dir: existing directory to write the per-task files into.
    """
    with open(pred_data_json, 'r') as f:
        # BUG FIX: original called json.load(pred_data_json) — passing the
        # path string instead of the open file handle, which raises.
        pred_data = json.load(f)

    gt_data_list = {}
    pred_data_list = {}
    for i in pred_data:
        task = i['task']
        if task not in pred_data_list:
            pred_data_list[task] = []
        cur_data = i
        cur_data['answer'] = cur_data['pred_output']
        cur_data.pop('source')
        cur_data.pop('task')
        cur_data.pop('gt_output')

        # Lazily load (and cache) the ground-truth file for this task.
        # Only the first 200 entries are used, matching create_test_data.
        if task not in gt_data_list:
            with open(os.path.join(gt_data_dir, f'4dor_{task}_instruct_0711_test.json'), 'r') as f:
                gt_data_list[task] = json.load(f)[:200]
        for j in gt_data_list[task]:
            if j['id'] == cur_data['id']:
                cur_data['question'] = j['conversations'][0]['value']
                # BUG FIX: str has no `start_with` method — the correct
                # name is `startswith`. Strip the leading '<image>' marker
                # plus its trailing newline (8 characters total).
                if cur_data['question'].startswith('<image>'):
                    cur_data['question'] = cur_data['question'][8:]

        pred_data_list[task].append(cur_data)
        # Rewritten each iteration so the file always reflects the data
        # seen so far; the final write holds the complete task list.
        output_file = os.path.join(output_dir, f'{task}_pred.json')
        with open(output_file, 'w') as f:
            json.dump(pred_data_list[task], f, indent=4)
| |
| |
def create_eval_dataset_from_output(pred_data_json, output_dir):
    """Group prediction entries by task, renaming ``pred_output`` to ``answer``.

    BUG FIX: the original built the per-task grouping and then silently
    discarded it (the function had no return and never touched
    *output_dir*), making the whole call a no-op. Return the grouping so
    callers can consume or write it.

    Args:
        pred_data_json: path to the combined predictions json.
        output_dir: currently unused — kept for interface compatibility.
            NOTE(review): confirm whether per-task files should be written
            here as create_eval_dataset does.

    Returns:
        dict mapping task name -> list of cleaned prediction dicts.
    """
    with open(pred_data_json) as f:
        pred_data = json.load(f)
    task_list = {}
    for entry in pred_data:
        task = entry['task']
        task_list.setdefault(task, [])
        cur_data = entry
        cur_data['answer'] = cur_data['pred_output']
        cur_data.pop('source')
        cur_data.pop('task')
        cur_data.pop('gt_output')

        task_list[task].append(cur_data)
    return task_list
| |
def debug():
    """Merge per-rank temp prediction files ``temp_0..temp_7.json`` into a
    combined ``preds.json`` plus one ``preds_{task}.json`` per task, then
    delete the temp files."""
    output_tasks = {}
    output_list = []
    for j in range(8):
        with open(f'./temp_{j}.json', 'r') as f:
            temp_output = json.load(f)
        for t in temp_output:
            if t['task'] not in output_tasks:
                output_tasks[t['task']] = []
            output_tasks[t['task']].append(t)
            output_list.append(t)
        os.remove(f'./temp_{j}.json')

    with open(os.path.join('.', 'preds.json'), 'w') as f:
        json.dump(output_list, f, indent=4)
    for k in output_tasks:
        # BUG FIX: original read os.path.join('.' f'preds_{k}.json') — the
        # missing comma made Python concatenate the literals, producing a
        # hidden file named '.preds_{k}.json' instead of './preds_{k}.json'.
        with open(os.path.join('.', f'preds_{k}.json'), 'w') as f:
            json.dump(output_tasks[k], f, indent=4)
|
|
def count_correct(json_file = '/mnt1/lyc/llava_finetune/eval_output/pwiseg_count_eval_llama3_llava.json', total = 200):
    """Count 'yes' answers in an eval output file and print the accuracy.

    Args:
        json_file: path to the eval json — a list of dicts (stray string
            entries are skipped).
        total: denominator for the printed rate. Defaults to 200, the
            per-task eval sample size used elsewhere in this script
            (previously hard-coded).

    Returns:
        The number of correct ('yes') answers.
    """
    with open(json_file) as f:
        data = json.load(f)
    num_corr = 0
    for d in data:
        # Skip malformed entries that were serialized as plain strings.
        if isinstance(d, str):
            continue
        if d['answer'] == 'yes':
            num_corr += 1
    print(f'Rate: {num_corr / total}')
    print(f'Num: {num_corr}')
    return num_corr
|
|
def process_relationship_data_v0(relationship_json_file = '/mnt1/wjl/LLaVA/data/4dor_infos_0702.json'):
    """Load the 4D-OR relationship annotations json.

    BUG FIX: the original parsed the file and then discarded the result,
    making the call a no-op. Return the parsed data so callers can use it.

    Args:
        relationship_json_file: path to the relationship-annotation json.

    Returns:
        The parsed json content.
    """
    with open(relationship_json_file) as f:
        relationship_json_data = json.load(f)
    return relationship_json_data
| |
|
|
if __name__ == '__main__':
    
    
    
    
    # Entry point: print the 'yes'-answer count and rate for the default
    # eval output file (see count_correct's default json_file path).
    count_correct()