# ViLaSR / eval / evaluate.py
# Uploaded by cornuHGF using huggingface_hub (commit 1efcb3c, verified).
import numpy as np
import json
import argparse
import re
import torch
from tqdm import tqdm
from utils.cal_metric_vsibench import calculate_average_scores_vsibench
from utils.cal_metric_sparbench import calculate_average_scores_sparbench
_ANSWER_RE = re.compile(r'<answer>\s*(.*?)\s*</answer>', re.DOTALL)


def extract_answer(text):
    """Return the content of the first ``<answer>...</answer>`` span in *text*.

    Whitespace around the content is stripped, and the span may cover
    multiple lines. If no such span exists, *text* is returned unchanged so
    bare (untagged) answers pass through as-is.
    """
    found = _ANSWER_RE.search(text)
    if found is None:
        return text
    return found.group(1).strip()
def normalize_number(num_str):
    """Parse a numeric answer string into a float.

    Commas are removed before parsing (treated as thousands separators), so
    "1,234.5" becomes 1234.5. Surrounding whitespace is accepted by float().

    Args:
        num_str: The answer text to parse.

    Returns:
        The parsed float, or None when *num_str* is not a string or does not
        contain a valid number.
    """
    try:
        return float(num_str.replace(',', ''))
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: not a str (e.g. None, bytes);
        # ValueError: text is not a number.
        return None
def mean_relative_accuracy(pred, target, start=0.5, end=0.95, interval=0.05):
    """Compute Mean Relative Accuracy (MRA) of *pred* against *target*.

    The relative error |pred - target| / (|target| + 1e-8) is compared
    against a sweep of confidence thresholds t in [start, end] with step
    *interval*. A prediction passes a threshold when its relative error is
    at most 1 - t. The result is the fraction of (element, threshold) checks
    that pass, as a Python float.

    Non-tensor inputs are converted to float32 tensors; tensor inputs are
    used as given.
    """
    def _as_tensor(value):
        if torch.is_tensor(value):
            return value
        return torch.tensor(value, dtype=torch.float32)

    pred_t = _as_tensor(pred)
    target_t = _as_tensor(target)
    eps = 1e-8  # keeps the division finite when target == 0
    relative_error = (pred_t - target_t).abs() / (target_t.abs() + eps)
    # end + interval/2 makes the arange include `end` despite float rounding.
    levels = torch.arange(start, end + interval / 2, interval, dtype=torch.float32)
    hits = relative_error <= (1 - levels)
    return hits.float().mean().item()
def compute_vci_score(output_ans, gt_ans):
    """Score a VCI action-string prediction against the ground truth.

    Both strings are parsed with parse_instruction into {action: amount}
    dicts. For each opposing action pair (e.g. move_right / move_left), the
    net amount (positive minus negative; missing actions count as 0) is
    compared with mean_relative_accuracy. The score is the mean over the
    five pairs.

    Args:
        output_ans: Predicted action string.
        gt_ans: Ground-truth action string.

    Returns:
        The mean MRA across action pairs, or 0.0 if either string cannot be
        parsed (the error is printed).
    """
    action_pairs = (
        ("move_right", "move_left"),
        ("move_up", "move_down"),
        ("move_forward", "move_backward"),
        ("rotate_right", "rotate_left"),
        ("rotate_up", "rotate_down"),
    )
    try:
        answer_dict = parse_instruction(output_ans)
        gt_dict = parse_instruction(gt_ans)
    except (AttributeError, TypeError, ValueError) as e:
        # Malformed model output is expected; score it 0 rather than abort.
        print(f"Error in VCI score calculation: {e}, output: {output_ans}")
        return 0.0

    mra_list = []
    for pos, neg in action_pairs:
        net_pred = answer_dict.get(pos, 0) - answer_dict.get(neg, 0)
        net_gt = gt_dict.get(pos, 0) - gt_dict.get(neg, 0)
        mra_list.append(mean_relative_accuracy(net_pred, net_gt))
    return np.mean(mra_list)
def parse_instruction(instruction):
    """Parse a comma-separated "name:value" action string into a dict.

    Example: "move_right:1, rotate_up:0.5" -> {"move_right": 1.0,
    "rotate_up": 0.5}. Whitespace around action names is stripped so that
    ", move_left:2" yields the key "move_left" rather than " move_left".
    Later duplicates overwrite earlier ones.

    Raises:
        ValueError: if any comma-separated segment (including an empty one)
            is not exactly one "name:value" pair, or the value is not numeric.
    """
    parsed = {}
    for item in instruction.split(","):
        key, value = item.split(":")
        parsed[key.strip()] = float(value)
    return parsed
def reward_fn(model_output, gt_ans, question_type):
    """Score one model prediction against its ground-truth answer.

    Both inputs may wrap the answer in ``<answer>...</answer>`` tags; the
    tagged content is extracted first (untagged text is used as-is).

    Args:
        model_output: Raw model output text.
        gt_ans: Ground-truth answer text.
        question_type: One of "multiple choice", "numerical", "regression"
            or "vci". Any other value scores 0.0.

    Returns:
        A score in [0, 1]:
        - "multiple choice": 1.0 if the first characters match
          case-insensitively; empty answers score 0.0.
        - "numerical": 1.0 if both parse and agree after rounding to two
          decimals, and both or neither contain a '.' or ',' character.
        - "regression": mean relative accuracy of the parsed numbers.
        - "vci": the score from compute_vci_score.
        Unparseable answers score 0.0.
    """
    output_ans = extract_answer(model_output)
    gt_ans = extract_answer(gt_ans)
    if question_type == "multiple choice":
        out_choice = output_ans.strip()
        gt_choice = gt_ans.strip()
        # An empty answer (missing output or "<answer></answer>") would
        # otherwise raise IndexError on [0] and abort the whole evaluation.
        if not out_choice or not gt_choice:
            return 0.0
        return 1.0 if out_choice[0].lower() == gt_choice[0].lower() else 0.0
    elif question_type == "numerical":
        # Require matching "decimal-ness": an integer answer must not be
        # matched by a decimal prediction and vice versa.
        gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
        out_has_decimal = ("." in output_ans) or ("," in output_ans)
        if gt_has_decimal != out_has_decimal:
            return 0.0
        gt_number = normalize_number(gt_ans)
        out_number = normalize_number(output_ans)
        if gt_number is None or out_number is None:
            return 0.0
        return 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
    elif question_type == "regression":
        gt_number = normalize_number(gt_ans)
        out_number = normalize_number(output_ans)
        if gt_number is None or out_number is None:
            return 0.0
        return mean_relative_accuracy(out_number, gt_number)
    elif question_type == "vci":
        return compute_vci_score(output_ans, gt_ans)
    else:
        return 0.0
def score_fixed_answer(args):
    """Score every prediction in a JSONL results file and save the scores.

    Each line of ``args.result_file`` is a JSON object. The keys read are
    'model_output', 'answer' and 'problem_type' (each defaulting to "").
    Every object is copied, given a 'score' from reward_fn, and written as
    JSONL to ``args.output_result``. The mean score is printed. For
    ``args.dataset`` 'vsi_bench' or 'spar_bench', a per-benchmark breakdown
    is also computed.

    Args:
        args: Parsed CLI namespace with result_file, output_result and
            dataset attributes.
    """
    # Use a context manager so the input file is closed, and read it as
    # UTF-8 to match the UTF-8 output written below. Blank lines (e.g. a
    # trailing newline) are skipped instead of crashing json.loads.
    with open(args.result_file, encoding='utf-8') as f:
        pred_answers = [json.loads(line) for line in f if line.strip()]
    print("Length: ", len(pred_answers))

    scores = []
    results = []
    for input_example in pred_answers:
        pred_answer_text = input_example.get('model_output', "")
        gt_answer_text = input_example.get('answer', "")
        problem_type = input_example.get("problem_type", "")
        score = reward_fn(pred_answer_text, gt_answer_text, problem_type)
        result = input_example.copy()
        result['score'] = score
        results.append(result)
        scores.append(score)
    print('The avg score is: %f' % np.mean(scores))

    with open(args.output_result, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')

    if args.dataset == 'vsi_bench':
        calculate_average_scores_vsibench(results)
    elif args.dataset == 'spar_bench':
        calculate_average_scores_sparbench(results)
    print(f"Results have been saved to {args.output_result} in JSONL format.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Score model predictions stored in a JSONL file."
    )
    parser.add_argument(
        "--dataset", type=str,
        help="Benchmark name; 'vsi_bench' or 'spar_bench' adds a breakdown.",
    )
    # Accepted for CLI compatibility; not read by score_fixed_answer.
    parser.add_argument("--question-file", type=str)
    # Both paths are required: without them open(None) fails with an
    # unhelpful TypeError deep inside score_fixed_answer.
    parser.add_argument(
        '--output-result', type=str, required=True,
        help="Path of the scored JSONL file to write.",
    )
    parser.add_argument(
        '--result-file', type=str, required=True,
        help="Path of the JSONL file with model predictions.",
    )
    args = parser.parse_args()
    score_fixed_answer(args)