import numpy as np
import json
import argparse
import re
import torch
from tqdm import tqdm
from utils.cal_metric_vsibench import calculate_average_scores_vsibench
from utils.cal_metric_sparbench import calculate_average_scores_sparbench


def extract_answer(text):
    """Return the content of the first <answer>...</answer> block, or the raw text if no tags are found."""
    pattern = r'<answer>\s*(.*?)\s*</answer>'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text
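
# Illustrative behaviour of extract_answer (example strings are ours, not from the source data):
#   extract_answer("<answer> B </answer>")  -> "B"
#   extract_answer("no tags here")          -> "no tags here"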


def normalize_number(num_str):
    """Parse a numeric string, dropping thousands separators; return None if parsing fails."""
    try:
        num_str = num_str.replace(',', '')
        return float(num_str)
    except Exception:
        return None


def mean_relative_accuracy(pred, target, start=0.5, end=0.95, interval=0.05):
    """Mean relative accuracy: the fraction of thresholds in [start, end] for which the
    relative error |pred - target| / |target| stays within 1 - threshold."""
    if not torch.is_tensor(pred):
        pred = torch.tensor(pred, dtype=torch.float32)
    if not torch.is_tensor(target):
        target = torch.tensor(target, dtype=torch.float32)
    epsilon = 1e-8  # guards against division by zero when the target is 0
    rel_error = torch.abs(pred - target) / (torch.abs(target) + epsilon)
    # end + interval/2 keeps `end` inside the range despite floating-point rounding
    thresholds = torch.arange(start, end + interval / 2, interval, dtype=torch.float32)
    conditions = rel_error <= (1 - thresholds)
    mra = conditions.float().mean()
    return mra.item()
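
# Worked example (values are ours): mean_relative_accuracy(8.2, 10.0) gives a relative
# error of 0.18, which satisfies the 7 thresholds 0.50 ... 0.80 out of 10, so MRA = 0.7.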


def compute_vci_score(output_ans, gt_ans):
    """Score a "vci"-type answer: collapse each opposing action pair (e.g. move_right / move_left)
    into a net signed value and average the per-pair mean relative accuracy against the ground truth."""
    ACTION_PAIRS = {
        "move_right_left": ("move_right", "move_left"),
        "move_up_down": ("move_up", "move_down"),
        "move_forward_backward": ("move_forward", "move_backward"),
        "rotate_right_left": ("rotate_right", "rotate_left"),
        "rotate_up_down": ("rotate_up", "rotate_down")
    }
    try:
        answer_dict = parse_instruction(output_ans)
        gt_dict = parse_instruction(gt_ans)
        answer_list = []
        gt_list = []
        for action_pair, (pos, neg) in ACTION_PAIRS.items():
            # Net motion for each pair: positive direction minus negative direction.
            net_pred = answer_dict.get(pos, 0) - answer_dict.get(neg, 0)
            net_gt = gt_dict.get(pos, 0) - gt_dict.get(neg, 0)
            answer_list.append(net_pred)
            gt_list.append(net_gt)
        mra_list = [
            mean_relative_accuracy(answer, gt)
            for gt, answer in zip(gt_list, answer_list)
        ]
        return np.mean(mra_list)
    except Exception as e:
        print(f"Error in VCI score calculation: {e}, output: {output_ans}")
        return 0.0


def parse_instruction(instruction):
    """Parse a comma-separated "action: value" instruction string into a dict of floats."""
    # .strip() tolerates optional whitespace around action names.
    return {k.strip(): float(v) for k, v in [item.split(":") for item in instruction.split(",")]}
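
# Assumed instruction format (illustrative values):
#   parse_instruction("move_right: 0.5, rotate_up: 30") -> {"move_right": 0.5, "rotate_up": 30.0}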


def reward_fn(model_output, gt_ans, question_type):
    """Return a score in [0, 1] comparing the model output with the ground truth for one example."""
    output_ans = extract_answer(model_output)
    gt_ans = extract_answer(gt_ans)
    if question_type == "multiple choice":
        # Compare only the leading option letter, case-insensitively; empty answers score 0.
        if not output_ans.strip() or not gt_ans.strip():
            return 0.0
        return 1.0 if output_ans.strip()[0].lower() == gt_ans.strip()[0].lower() else 0.0
    elif question_type == "numerical":
        # The prediction must match the ground truth in containing (or not containing) a '.' or ','.
        gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
        out_has_decimal = ("." in output_ans) or ("," in output_ans)
        if gt_has_decimal != out_has_decimal:
            return 0.0
        gt_number = normalize_number(gt_ans)
        out_number = normalize_number(output_ans)
        if gt_number is None or out_number is None:
            return 0.0
        return 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
    elif question_type == "regression":
        gt_number = normalize_number(gt_ans)
        out_number = normalize_number(output_ans)
        if gt_number is None or out_number is None:
            return 0.0
        return mean_relative_accuracy(out_number, gt_number)
    elif question_type == "vci":
        return compute_vci_score(output_ans, gt_ans)
    else:
        return 0.0
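
# Illustrative calls (answer strings are made up):
#   reward_fn("<answer>B</answer>", "<answer>B. sofa</answer>", "multiple choice") -> 1.0
#   reward_fn("<answer>3.20</answer>", "<answer>3.2</answer>", "numerical")        -> 1.0
#   reward_fn("<answer>8.2</answer>", "<answer>10</answer>", "regression")         -> 0.7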


def score_fixed_answer(args):
    scores = []
    results = []
    with open(args.result_file, 'r', encoding='utf-8') as f:
        pred_answers = [json.loads(line) for line in f]
    print("Length: ", len(pred_answers))
    for input_example in tqdm(pred_answers):
        pred_answer_text = input_example.get('model_output', "")
        gt_answer_text = input_example.get('answer', "")
        problem_type = input_example.get("problem_type", "")
        result = input_example.copy()
        score = reward_fn(pred_answer_text, gt_answer_text, problem_type)
        result['score'] = score
        results.append(result)
        scores.append(score)
    print('The avg score is: %f' % np.mean(scores))
    with open(args.output_result, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
    # Benchmark-specific per-category breakdowns.
    if args.dataset == 'vsi_bench':
        calculate_average_scores_vsibench(results)
    elif args.dataset == 'spar_bench':
        calculate_average_scores_sparbench(results)
    print(f"Results have been saved to {args.output_result} in JSONL format.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--question-file", type=str)
    parser.add_argument('--output-result', type=str)
    parser.add_argument('--result-file', type=str)
    args = parser.parse_args()
    score_fixed_answer(args)
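
# Example invocation (script and file names are illustrative):
#   python score_fixed_answer.py --dataset vsi_bench \
#       --result-file predictions.jsonl --output-result scored_results.jsonl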