| | import argparse
|
| | import json
|
| | import os
|
| | import re
|
| | import random
|
| |
|
| |
|
def get_args(argv=None):
    """Parse command-line arguments for the ScienceQA evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``base_dir``, ``result_file``,
        ``output_file``, ``output_result``, ``split`` and ``options``.
    """
    parser = argparse.ArgumentParser()
    # Both inputs are read unconditionally as soon as the script starts, so
    # require them up front rather than failing later with a TypeError on
    # os.path.join(None, ...) / open(None).
    parser.add_argument('--base-dir', type=str, required=True,
                        help='Directory containing pid_splits.json and problems.json.')
    parser.add_argument('--result-file', type=str, required=True,
                        help='JSON-lines file of model predictions.')
    parser.add_argument('--output-file', type=str,
                        help='Where to write the per-question correct/incorrect analysis.')
    parser.add_argument('--output-result', type=str,
                        help='Where to write the aggregate accuracy summary.')
    parser.add_argument('--split', type=str, default='test',
                        help='Key into pid_splits.json selecting the problems to score.')
    # type=list splits the raw string into single characters,
    # e.g. "--options ABCD" -> ['A', 'B', 'C', 'D'].
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"],
                        help='Answer labels as one string of letters, e.g. ABCDE.')
    return parser.parse_args(argv)
|
| |
|
| |
|
def convert_caps(results):
    """Convert prediction records into caption entries.

    Each record's ``question_id`` is cast to ``int`` and stored as
    ``image_id``; its ``text`` is stored as ``caption``. Input order is kept.
    """
    id_text_pairs = ((record['question_id'], record['text']) for record in results)
    return [{"image_id": int(qid), "caption": text} for qid, text in id_text_pairs]
|
| |
|
| |
|
def get_pred_idx(prediction, choices, options):
    """Get the index (e.g. 2) from the prediction (e.g. 'C').

    Args:
        prediction: Parsed answer label, e.g. ``'C'`` or ``'FAILED'``.
        choices: The problem's answer choices; only ``len(choices)`` is used,
            to limit which labels count as valid for this problem.
        options: Ordered answer labels, e.g. ``['A', 'B', 'C', 'D', 'E']``.

    Returns:
        The position of ``prediction`` among the first ``len(choices)``
        labels of ``options``, or ``-1`` when it is not one of them.
    """
    valid_labels = options[:len(choices)]
    if prediction in valid_labels:
        return valid_labels.index(prediction)
    # Unparsed or out-of-range labels map to -1, a sentinel meaning
    # "no valid choice was predicted".
    return -1
|
| |
|
| |
|
if __name__ == "__main__":
    args = get_args()

    base_dir = args.base_dir
    # Context managers close the input files promptly instead of leaking
    # the handles until interpreter exit.
    with open(os.path.join(base_dir, "pid_splits.json"), encoding="utf-8") as f:
        split_indices = json.load(f)[args.split]
    with open(os.path.join(base_dir, "problems.json"), encoding="utf-8") as f:
        problems = json.load(f)
    # The result file has one JSON object per line; skip blank lines (e.g. a
    # trailing newline) instead of crashing in json.loads.
    with open(args.result_file, encoding="utf-8") as f:
        predictions = [json.loads(line) for line in f if line.strip()]
    predictions = {pred['question_id']: pred for pred in predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}

    results = {'correct': [], 'incorrect': []}
    # Key order matches the JSON written at the end of the script.
    sqa_results = {
        'acc': None,
        'correct': None,
        'count': None,
        'results': {},
        'outputs': {},
    }

    # Compiled once, outside the loop. NOTE: the trailing '.' is an unescaped
    # regex wildcard, so it matches any character after the letter, not only
    # a literal period; kept as-is so scores stay comparable.
    answer_pattern = re.compile(r'The answer is ([A-Z]).')

    for prob_id, prob in split_problems.items():
        if prob_id not in predictions:
            # A missing prediction is scored as a parse failure.
            pred = {'text': 'FAILED', 'prompt': 'Unknown'}
            pred_text = 'FAILED'
        else:
            pred = predictions[prob_id]
            pred_text = pred['text']

        # Accept a bare label ("B"), a "B. ..." prefix, or exactly one
        # "The answer is B" phrase; anything else becomes "FAILED".
        if pred_text in args.options:
            answer = pred_text
        elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
            answer = pred_text[0]
        else:
            res = answer_pattern.findall(pred_text)
            if len(res) == 1:
                answer = res[0]
            else:
                answer = "FAILED"

        pred_idx = get_pred_idx(answer, prob['choices'], args.options)

        analysis = {
            'question_id': prob_id,
            'parsed_ans': answer,
            'ground_truth': args.options[prob['answer']],
            'question': pred['prompt'],
            'pred': pred_text,
            'is_multimodal': '<image>' in pred['prompt'],
        }

        # Reuse pred_idx rather than recomputing get_pred_idx.
        sqa_results['results'][prob_id] = pred_idx
        sqa_results['outputs'][prob_id] = pred_text

        if pred_idx == prob['answer']:
            results['correct'].append(analysis)
        else:
            results['incorrect'].append(analysis)

    correct = len(results['correct'])
    total = correct + len(results['incorrect'])

    multimodal_correct = sum(1 for x in results['correct'] if x['is_multimodal'])
    multimodal_incorrect = sum(1 for x in results['incorrect'] if x['is_multimodal'])
    multimodal_total = multimodal_correct + multimodal_incorrect

    # Avoid ZeroDivisionError on an empty split, or when no prompt contains
    # '<image>'; report N/A in those cases.
    acc = correct / total * 100 if total else None
    img_acc = multimodal_correct / multimodal_total * 100 if multimodal_total else None
    acc_str = f'{acc:.2f}%' if acc is not None else 'N/A'
    img_acc_str = f'{img_acc:.2f}%' if img_acc is not None else 'N/A'

    print(f'Total: {total}, Correct: {correct}, Accuracy: {acc_str}, IMG-Accuracy: {img_acc_str}')

    sqa_results['acc'] = acc
    sqa_results['correct'] = correct
    sqa_results['count'] = total

    with open(args.output_file, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    with open(args.output_result, 'w', encoding="utf-8") as f:
        json.dump(sqa_results, f, indent=2)
|
| |
|