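"""Score model predictions stored as JSONL against ground-truth answers and report
the average score, plus per-category averages for VSI-Bench and SPAR-Bench."""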
import numpy as np
import json
import argparse
import re
import torch
from tqdm import tqdm
from utils.cal_metric_vsibench import calculate_average_scores_vsibench
from utils.cal_metric_sparbench import calculate_average_scores_sparbench


def extract_answer(text):
    """Return the content of the first <answer>...</answer> block, or the raw text if none is found."""
    pattern = r'<answer>\s*(.*?)\s*</answer>'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text


def normalize_number(num_str):
    """Parse a numeric string, tolerating comma separators; return None if parsing fails."""
    try:
        num_str = num_str.replace(',', '')
        return float(num_str)
    except Exception:
        return None


def mean_relative_accuracy(pred, target, start=0.5, end=0.95, interval=0.05):
    """Mean Relative Accuracy (MRA): the fraction of thresholds in [start, end] for which
    the relative error |pred - target| / (|target| + eps) is at most 1 - threshold."""
    if not torch.is_tensor(pred):
        pred = torch.tensor(pred, dtype=torch.float32)
    if not torch.is_tensor(target):
        target = torch.tensor(target, dtype=torch.float32)
    epsilon = 1e-8
    rel_error = torch.abs(pred - target) / (torch.abs(target) + epsilon)
    thresholds = torch.arange(start, end + interval / 2, interval, dtype=torch.float32)
    conditions = rel_error <= (1 - thresholds)
    mra = conditions.float().mean()
    return mra.item()


def compute_vci_score(output_ans, gt_ans):
    """Score a VCI-type answer: collapse each pair of opposing actions into a signed net
    value, then average the per-axis mean relative accuracy against the ground truth."""
    ACTION_PAIRS = {
        "move_right_left": ("move_right", "move_left"),
        "move_up_down": ("move_up", "move_down"),
        "move_forward_backward": ("move_forward", "move_backward"),
        "rotate_right_left": ("rotate_right", "rotate_left"),
        "rotate_up_down": ("rotate_up", "rotate_down")
    }
    try:
        answer_dict = parse_instruction(output_ans)
        gt_dict = parse_instruction(gt_ans)
        answer_list = []
        gt_list = []
        for action_pair, (pos, neg) in ACTION_PAIRS.items():
            # Net value along each axis, e.g. move_right minus move_left.
            net_pred = answer_dict.get(pos, 0) - answer_dict.get(neg, 0)
            net_gt = gt_dict.get(pos, 0) - gt_dict.get(neg, 0)
            answer_list.append(net_pred)
            gt_list.append(net_gt)
        mra_list = [
            mean_relative_accuracy(answer, gt)
            for gt, answer in zip(gt_list, answer_list)
        ]
        return np.mean(mra_list)
    except Exception as e:
        print(f"Error in VCI score calculation: {e}, output: {output_ans}")
        return 0.0


def parse_instruction(instruction):
    """Parse a comma-separated 'action:value' string into a dict of floats.
    Keys are stripped so whitespace after commas does not break the lookup."""
    return {k.strip(): float(v) for k, v in [item.split(":") for item in instruction.split(",")]}


def reward_fn(model_output, gt_ans, question_type):
    output_ans = extract_answer(model_output)
    gt_ans = extract_answer(gt_ans)
    if question_type == "multiple choice":
        # Compare only the leading option letter, case-insensitively; empty answers score 0.
        if not output_ans.strip() or not gt_ans.strip():
            return 0.0
        return 1.0 if output_ans.strip()[0].lower() == gt_ans.strip()[0].lower() else 0.0
    elif question_type == "numerical":
        # The prediction must match the ground truth's integer vs. decimal formatting before values are compared.
        gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
        out_has_decimal = ("." in output_ans) or ("," in output_ans)
        if gt_has_decimal != out_has_decimal:
            return 0.0
        gt_number = normalize_number(gt_ans)
        out_number = normalize_number(output_ans)
        if gt_number is None or out_number is None:
            return 0.0
        return 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
    elif question_type == "regression":
        gt_number = normalize_number(gt_ans)
        out_number = normalize_number(output_ans)
        if gt_number is None or out_number is None:
            return 0.0
        mra = mean_relative_accuracy(out_number, gt_number)
        return mra
    elif question_type == "vci":
        return compute_vci_score(output_ans, gt_ans)
    else:
        return 0.0


def score_fixed_answer(args):
    scores = []
    results = []
    with open(args.result_file) as f:
        pred_answers = [json.loads(line) for line in f]
    print("Length: ", len(pred_answers))
    for input_example in tqdm(pred_answers):
        pred_answer_text = input_example.get('model_output', "")
        gt_answer_text = input_example.get('answer', "")
        problem_type = input_example.get("problem_type", "")
        result = input_example.copy()
        score = reward_fn(pred_answer_text, gt_answer_text, problem_type)
        result['score'] = score
        results.append(result)
        scores.append(score)
    print('The avg score is: %f' % np.mean(scores))
    with open(args.output_result, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
    if args.dataset == 'vsi_bench':
        calculate_average_scores_vsibench(results)
    elif args.dataset == 'spar_bench':
        calculate_average_scores_sparbench(results)
    print(f"Results have been saved to {args.output_result} in JSONL format.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--question-file", type=str)
    parser.add_argument('--output-result', type=str)
    parser.add_argument('--result-file', type=str)
    args = parser.parse_args()
    score_fixed_answer(args)
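
# Example invocation (script name and file paths are hypothetical):
#   python score_answers.py --dataset vsi_bench \
#       --result-file results/vsi_bench_predictions.jsonl \
#       --output-result results/vsi_bench_scored.jsonl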