Spaces:
Sleeping
Sleeping
Paul Gavrikov
commited on
Commit
·
dfe5ccf
1
Parent(s):
b06fec2
updating eval
Browse files- app.py +91 -29
- config.py +2 -2
- eval_utils.py +0 -40
- ground_truth.json +0 -1
- ground_truth.secret +0 -0
- judge.py +473 -0
app.py
CHANGED
|
@@ -1,15 +1,57 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import json, os, time, uuid
|
| 3 |
-
from eval_utils import evaluate_submission, clean_submission
|
| 4 |
import pandas as pd
|
| 5 |
from rate_limiter import RateLimiter, RateLimitConfig
|
| 6 |
from config import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def login_check(profile: gr.OAuthProfile | None):
|
| 10 |
visible = profile is not None
|
| 11 |
welcome = (
|
| 12 |
-
f"##
|
| 13 |
if visible
|
| 14 |
else "🔒 Please sign in to submit."
|
| 15 |
)
|
|
@@ -22,7 +64,7 @@ def login_check(profile: gr.OAuthProfile | None):
|
|
| 22 |
gr.Markdown(value=quota_details, visible=visible),
|
| 23 |
gr.Textbox(visible=visible & SAVE_SUBMISSIONS),
|
| 24 |
gr.File(visible=visible),
|
| 25 |
-
gr.Checkbox(visible=visible),
|
| 26 |
gr.Checkbox(visible=visible & SAVE_SUBMISSIONS),
|
| 27 |
gr.Button(visible=visible),
|
| 28 |
)
|
|
@@ -53,7 +95,7 @@ def quoata_check(profile: gr.OAuthProfile | None):
|
|
| 53 |
def submit(
|
| 54 |
submission_id: str,
|
| 55 |
submission_file: str,
|
| 56 |
-
is_cleaning: bool,
|
| 57 |
is_private: bool,
|
| 58 |
profile: gr.OAuthProfile | None,
|
| 59 |
):
|
|
@@ -71,17 +113,18 @@ def submit(
|
|
| 71 |
prediction_json = json.load(file)
|
| 72 |
except json.JSONDecodeError:
|
| 73 |
raise gr.Error("❌ Submission file is invalid JSON.")
|
| 74 |
-
|
| 75 |
-
with open(GROUND_TRUTH, "rb") as file:
|
| 76 |
-
ground_truth_json = json.load(file)
|
| 77 |
-
|
| 78 |
try:
|
| 79 |
-
if is_cleaning:
|
| 80 |
-
|
| 81 |
-
score_dict = evaluate_submission(prediction_json, ground_truth_json)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
except Exception as e:
|
| 83 |
print(e)
|
| 84 |
-
raise gr.Error(f"❌ Invalid submission format.")
|
| 85 |
|
| 86 |
allowed, allowed_reason = limiter.is_allowed(username)
|
| 87 |
status = limiter.get_status(username)
|
|
@@ -117,15 +160,29 @@ def submit(
|
|
| 117 |
|
| 118 |
json.dump(data, open(os.path.join(SUB_DIR, submission_record + ".json"), "w"))
|
| 119 |
|
| 120 |
-
score_response = f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
return gr.Text(score_response, visible=True)
|
| 123 |
|
| 124 |
|
| 125 |
def get_leaderboard() -> pd.DataFrame | str:
|
| 126 |
-
df = pd.read_csv("leaderboard.csv")
|
| 127 |
df = df.sort_values(by="Total", ascending=False)
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
def get_quota(profile: gr.OAuthProfile | None = None):
|
|
@@ -137,18 +194,18 @@ with gr.Blocks() as app:
|
|
| 137 |
if SHOW_LEADERBOARD:
|
| 138 |
with gr.Tab("🏆 Public Leaderboard"):
|
| 139 |
leaderboard_heading_md = gr.Markdown(
|
| 140 |
-
|
| 141 |
)
|
| 142 |
leaderboard_table = gr.Dataframe(get_leaderboard())
|
| 143 |
leaderboard_footer_md = gr.Markdown(
|
| 144 |
-
|
| 145 |
)
|
| 146 |
if SHOW_EVAL_SERVER:
|
| 147 |
with gr.Tab("🚀 Evaluation"):
|
| 148 |
login_button = gr.LoginButton()
|
| 149 |
welcome_md = gr.Markdown("🔒 Please sign in to submit.")
|
| 150 |
welcome_details_md = gr.Markdown(
|
| 151 |
-
|
| 152 |
visible=False,
|
| 153 |
)
|
| 154 |
submission_file = gr.File(
|
|
@@ -157,17 +214,17 @@ with gr.Blocks() as app:
|
|
| 157 |
submission_id = gr.Textbox(
|
| 158 |
label="(Optional) Submission identifier", visible=False
|
| 159 |
)
|
| 160 |
-
clean_flag = gr.Checkbox(
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
)
|
| 165 |
private_flag = gr.Checkbox(
|
| 166 |
label="Do not save my submission", value=False, visible=False
|
| 167 |
)
|
| 168 |
quota_details = gr.Markdown(visible=False)
|
| 169 |
submit_btn = gr.Button("Submit", visible=False)
|
| 170 |
-
result = gr.
|
| 171 |
|
| 172 |
# Load login state → show/hide components
|
| 173 |
app.load(
|
|
@@ -179,7 +236,7 @@ with gr.Blocks() as app:
|
|
| 179 |
quota_details,
|
| 180 |
submission_id,
|
| 181 |
submission_file,
|
| 182 |
-
clean_flag,
|
| 183 |
private_flag,
|
| 184 |
submit_btn,
|
| 185 |
],
|
|
@@ -187,7 +244,9 @@ with gr.Blocks() as app:
|
|
| 187 |
|
| 188 |
futures = submit_btn.click(
|
| 189 |
fn=submit,
|
| 190 |
-
inputs=[submission_id, submission_file,
|
|
|
|
|
|
|
| 191 |
outputs=[result],
|
| 192 |
).then(quoata_check, outputs=[quota_details])
|
| 193 |
|
|
@@ -197,9 +256,9 @@ with gr.Blocks() as app:
|
|
| 197 |
outputs=[leaderboard_table],
|
| 198 |
)
|
| 199 |
|
| 200 |
-
copyright = gr.Markdown(
|
| 201 |
-
|
| 202 |
-
)
|
| 203 |
|
| 204 |
|
| 205 |
if __name__ == "__main__":
|
|
@@ -211,4 +270,7 @@ if __name__ == "__main__":
|
|
| 211 |
)
|
| 212 |
limiter = RateLimiter(config)
|
| 213 |
|
|
|
|
|
|
|
|
|
|
| 214 |
app.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import json, os, time, uuid
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
from rate_limiter import RateLimiter, RateLimitConfig
|
| 5 |
from config import *
|
| 6 |
+
from judge import build_question_index_from_json, judge
|
| 7 |
+
import pprint
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
SUBMISSION_TEXT = """Upload a json file with your predictions and click Submit. Your predictions should be a list of dictionaries, each containing an \"question_id\" field and a \"response\" field. For multiple choice questions, the \"response\" field should contain the predicted answer choice. For open-ended questions, the \"response\" field should contain the option letter (A-D). We will apply simple heuistics to clean the responses, but please ensure they are as accurate as possible.
|
| 11 |
+
|
| 12 |
+
Example:
|
| 13 |
+
```
|
| 14 |
+
[
|
| 15 |
+
{\"question_id\": \"28deb79e\", \"response\": \"A\"},
|
| 16 |
+
{\"question_id\": \"73cbabd7\", \"response\": \"C\"},
|
| 17 |
+
...
|
| 18 |
+
]
|
| 19 |
+
```
|
| 20 |
+
Your file:"""
|
| 21 |
+
|
| 22 |
+
INTRO_TEXT = """# Welcome to the VisualOverload Leaderboard!
|
| 23 |
+
Below you will find the public leaderboard for the [VisualOverload benchmark](https://huggingface.co/datasets/paulgavrikov/visualoverload), which evaluates models on their ability to understand and reason about complex visual scenes. We seperate by models and 'special' inference techniques (e.g., special prompts, ICL, CoT etc.) to better understand the source of their performance.
|
| 24 |
+
|
| 25 |
+
The leaderboard ranks models based on their overall accuracy across a six tasks (activity recognition, attribute recognition, counting, OCR, reasoning, and global scene recognition). We provide an aggregate score (Total) as well as individual scores on three distinct splits per difficulty (Easy, Medium, Hard), and each task."""
|
| 26 |
+
|
| 27 |
+
INTRO_DETAILS = "Please see the evaluation tab for evaluation and details on how to list your results."
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_ground_truth():
|
| 31 |
+
from cryptography.fernet import Fernet
|
| 32 |
+
|
| 33 |
+
key = os.getenv("SECRET_KEY")
|
| 34 |
+
|
| 35 |
+
cipher = Fernet(key)
|
| 36 |
+
with open("ground_truth.secret", "rb") as f:
|
| 37 |
+
json_data = json.loads(cipher.decrypt(f.read().decode()))
|
| 38 |
+
|
| 39 |
+
hardness_levels = json_data["splits"]
|
| 40 |
+
df_gt = pd.DataFrame.from_dict(json_data["benchmark"])
|
| 41 |
+
df_gt.question_id = df_gt.question_id.astype(str)
|
| 42 |
+
df_gt = df_gt.set_index("question_id")
|
| 43 |
+
|
| 44 |
+
for level, ids in hardness_levels.items():
|
| 45 |
+
ids = [str(i) for i in ids]
|
| 46 |
+
df_gt.loc[ids, "difficulty"] = level
|
| 47 |
+
|
| 48 |
+
return df_gt.reset_index()
|
| 49 |
|
| 50 |
|
| 51 |
def login_check(profile: gr.OAuthProfile | None):
|
| 52 |
visible = profile is not None
|
| 53 |
welcome = (
|
| 54 |
+
f"## Welcome to the evaluation server @{profile.username} 👋"
|
| 55 |
if visible
|
| 56 |
else "🔒 Please sign in to submit."
|
| 57 |
)
|
|
|
|
| 64 |
gr.Markdown(value=quota_details, visible=visible),
|
| 65 |
gr.Textbox(visible=visible & SAVE_SUBMISSIONS),
|
| 66 |
gr.File(visible=visible),
|
| 67 |
+
# gr.Checkbox(visible=visible),
|
| 68 |
gr.Checkbox(visible=visible & SAVE_SUBMISSIONS),
|
| 69 |
gr.Button(visible=visible),
|
| 70 |
)
|
|
|
|
| 95 |
def submit(
|
| 96 |
submission_id: str,
|
| 97 |
submission_file: str,
|
| 98 |
+
# is_cleaning: bool,
|
| 99 |
is_private: bool,
|
| 100 |
profile: gr.OAuthProfile | None,
|
| 101 |
):
|
|
|
|
| 113 |
prediction_json = json.load(file)
|
| 114 |
except json.JSONDecodeError:
|
| 115 |
raise gr.Error("❌ Submission file is invalid JSON.")
|
| 116 |
+
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
+
# if is_cleaning:
|
| 119 |
+
# prediction_json = clean_submission(prediction_json)
|
| 120 |
+
# score_dict = evaluate_submission(prediction_json, ground_truth_json)
|
| 121 |
+
|
| 122 |
+
_, score_dict = judge(prediction_json, question_index)
|
| 123 |
+
score_dict = {k: round(v * 100, 1) for k, v in score_dict.items() if k.startswith("accuracy/")}
|
| 124 |
+
|
| 125 |
except Exception as e:
|
| 126 |
print(e)
|
| 127 |
+
raise gr.Error(f"❌ Invalid submission format. Check logs for details.")
|
| 128 |
|
| 129 |
allowed, allowed_reason = limiter.is_allowed(username)
|
| 130 |
status = limiter.get_status(username)
|
|
|
|
| 160 |
|
| 161 |
json.dump(data, open(os.path.join(SUB_DIR, submission_record + ".json"), "w"))
|
| 162 |
|
| 163 |
+
score_response = f"""
|
| 164 |
+
Your submission has been evaluated!
|
| 165 |
+
|
| 166 |
+
```
|
| 167 |
+
{pprint.pformat(score_dict, indent=4, sort_dicts=False)}
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
If you want your submission to appear on the public leaderboard, please follow the instructions to open a ticket at [https://github.com/paulgavrikov/visualoverload/issues](https://github.com/paulgavrikov/visualoverload/issues).
|
| 171 |
+
"""
|
| 172 |
|
| 173 |
return gr.Text(score_response, visible=True)
|
| 174 |
|
| 175 |
|
| 176 |
def get_leaderboard() -> pd.DataFrame | str:
|
| 177 |
+
df = pd.read_csv("leaderboard.csv").set_index(["Model", "Special Inference"])
|
| 178 |
df = df.sort_values(by="Total", ascending=False)
|
| 179 |
+
|
| 180 |
+
df = df.reset_index()
|
| 181 |
+
|
| 182 |
+
float_cols = df.select_dtypes(include=["float"]).columns
|
| 183 |
+
styler = df.style.format('{:.1f}', subset=float_cols)
|
| 184 |
+
|
| 185 |
+
return styler
|
| 186 |
|
| 187 |
|
| 188 |
def get_quota(profile: gr.OAuthProfile | None = None):
|
|
|
|
| 194 |
if SHOW_LEADERBOARD:
|
| 195 |
with gr.Tab("🏆 Public Leaderboard"):
|
| 196 |
leaderboard_heading_md = gr.Markdown(
|
| 197 |
+
INTRO_TEXT
|
| 198 |
)
|
| 199 |
leaderboard_table = gr.Dataframe(get_leaderboard())
|
| 200 |
leaderboard_footer_md = gr.Markdown(
|
| 201 |
+
INTRO_DETAILS
|
| 202 |
)
|
| 203 |
if SHOW_EVAL_SERVER:
|
| 204 |
with gr.Tab("🚀 Evaluation"):
|
| 205 |
login_button = gr.LoginButton()
|
| 206 |
welcome_md = gr.Markdown("🔒 Please sign in to submit.")
|
| 207 |
welcome_details_md = gr.Markdown(
|
| 208 |
+
SUBMISSION_TEXT,
|
| 209 |
visible=False,
|
| 210 |
)
|
| 211 |
submission_file = gr.File(
|
|
|
|
| 214 |
submission_id = gr.Textbox(
|
| 215 |
label="(Optional) Submission identifier", visible=False
|
| 216 |
)
|
| 217 |
+
# clean_flag = gr.Checkbox(
|
| 218 |
+
# label="Attempt to clean my submission (Recommended for raw responses)",
|
| 219 |
+
# value=True,
|
| 220 |
+
# visible=False,
|
| 221 |
+
# )
|
| 222 |
private_flag = gr.Checkbox(
|
| 223 |
label="Do not save my submission", value=False, visible=False
|
| 224 |
)
|
| 225 |
quota_details = gr.Markdown(visible=False)
|
| 226 |
submit_btn = gr.Button("Submit", visible=False)
|
| 227 |
+
result = gr.Markdown(label="✅ Submission processed", visible=False)
|
| 228 |
|
| 229 |
# Load login state → show/hide components
|
| 230 |
app.load(
|
|
|
|
| 236 |
quota_details,
|
| 237 |
submission_id,
|
| 238 |
submission_file,
|
| 239 |
+
# clean_flag,
|
| 240 |
private_flag,
|
| 241 |
submit_btn,
|
| 242 |
],
|
|
|
|
| 244 |
|
| 245 |
futures = submit_btn.click(
|
| 246 |
fn=submit,
|
| 247 |
+
inputs=[submission_id, submission_file,
|
| 248 |
+
# clean_flag,
|
| 249 |
+
private_flag],
|
| 250 |
outputs=[result],
|
| 251 |
).then(quoata_check, outputs=[quota_details])
|
| 252 |
|
|
|
|
| 256 |
outputs=[leaderboard_table],
|
| 257 |
)
|
| 258 |
|
| 259 |
+
# copyright = gr.Markdown(
|
| 260 |
+
# "Based on the [gradio-eval-server-template](https://github.com/paulgavrikov/gradio-eval-server-template) by Paul Gavrikov."
|
| 261 |
+
# )
|
| 262 |
|
| 263 |
|
| 264 |
if __name__ == "__main__":
|
|
|
|
| 270 |
)
|
| 271 |
limiter = RateLimiter(config)
|
| 272 |
|
| 273 |
+
df_ground_truth = load_ground_truth()
|
| 274 |
+
question_index = build_question_index_from_json(df_ground_truth.to_dict(orient="records")) # TODO: this should be precomputed once and reused
|
| 275 |
+
|
| 276 |
app.launch()
|
config.py
CHANGED
|
@@ -5,9 +5,9 @@ SUB_DIR = "./submissions"
|
|
| 5 |
# Wait time for each submission in seconds, or 0 to disable minimum wait time
|
| 6 |
RATE_LIMIT_MIN_INT_SEC = 10
|
| 7 |
# Maximum total number of submissions per user, or 0 for no limit
|
| 8 |
-
MAX_TOTAL_SUBMISSIONS_PER_USER =
|
| 9 |
# Maxmimum number of submissions per user per day, or 0 for no limit
|
| 10 |
-
MAX_SUBMISSIONS_PER_USER_PER_DAY =
|
| 11 |
# Save submissions
|
| 12 |
SAVE_SUBMISSIONS = False
|
| 13 |
# Save submissions
|
|
|
|
| 5 |
# Wait time for each submission in seconds, or 0 to disable minimum wait time
|
| 6 |
RATE_LIMIT_MIN_INT_SEC = 10
|
| 7 |
# Maximum total number of submissions per user, or 0 for no limit
|
| 8 |
+
MAX_TOTAL_SUBMISSIONS_PER_USER = 0
|
| 9 |
# Maxmimum number of submissions per user per day, or 0 for no limit
|
| 10 |
+
MAX_SUBMISSIONS_PER_USER_PER_DAY = 5
|
| 11 |
# Save submissions
|
| 12 |
SAVE_SUBMISSIONS = False
|
| 13 |
# Save submissions
|
eval_utils.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
class AverageMeter(object):
|
| 5 |
-
"""Computes and stores the average and current value"""
|
| 6 |
-
|
| 7 |
-
def __init__(self):
|
| 8 |
-
self.reset()
|
| 9 |
-
|
| 10 |
-
def reset(self):
|
| 11 |
-
self.val = 0
|
| 12 |
-
self.avg = 0
|
| 13 |
-
self.sum = 0
|
| 14 |
-
self.count = 0
|
| 15 |
-
|
| 16 |
-
def update(self, val, n=1):
|
| 17 |
-
self.val = val
|
| 18 |
-
self.sum += val * n
|
| 19 |
-
self.count += n
|
| 20 |
-
self.avg = self.sum / self.count
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def clean_submission(preds):
|
| 24 |
-
cleaned_preds = preds
|
| 25 |
-
|
| 26 |
-
# TODO: Implement your cleaning logic here
|
| 27 |
-
|
| 28 |
-
return cleaned_preds
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def evaluate_submission(preds, ground_truth) -> dict:
|
| 32 |
-
# TODO: Implement your evaluation logic
|
| 33 |
-
|
| 34 |
-
accuracy_meter = defaultdict(AverageMeter)
|
| 35 |
-
|
| 36 |
-
accuracy_meter["partition1"].update(0.43)
|
| 37 |
-
accuracy_meter["partition2"].update(0.42)
|
| 38 |
-
accuracy_meter["partition3"].update(0.99)
|
| 39 |
-
|
| 40 |
-
return dict([(f"accuracy/{k}", 100 * v.avg) for k, v in accuracy_meter.items()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ground_truth.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{}
|
|
|
|
|
|
ground_truth.secret
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
judge.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from glob import glob
|
| 4 |
+
import math
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
import re
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def make_option_labels(options):
|
| 11 |
+
option_labels = []
|
| 12 |
+
for i, option in enumerate(options):
|
| 13 |
+
option_labels.append(f"{chr(65 + i)}. {option.strip()}")
|
| 14 |
+
return option_labels
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AverageMeter(object):
|
| 18 |
+
"""
|
| 19 |
+
Computes and stores the average and current value.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
self.reset()
|
| 24 |
+
|
| 25 |
+
def reset(self):
|
| 26 |
+
self.val = 0
|
| 27 |
+
self.avg = 0
|
| 28 |
+
self.sum = 0
|
| 29 |
+
self.count = 0
|
| 30 |
+
|
| 31 |
+
def update(self, val, n=1):
|
| 32 |
+
self.val = val
|
| 33 |
+
self.sum += val * n
|
| 34 |
+
self.count += n
|
| 35 |
+
self.avg = self.sum / self.count
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def clean_counting_answer(answer):
|
| 39 |
+
|
| 40 |
+
if answer is None or answer.strip() == "":
|
| 41 |
+
return 0
|
| 42 |
+
|
| 43 |
+
answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
|
| 44 |
+
if len(answer_matches) > 0:
|
| 45 |
+
answer = answer_matches[0]
|
| 46 |
+
|
| 47 |
+
clean_answer = answer.strip().lower()
|
| 48 |
+
if clean_answer.endswith("."):
|
| 49 |
+
clean_answer = clean_answer[:-1].strip()
|
| 50 |
+
|
| 51 |
+
if clean_answer in words_to_int:
|
| 52 |
+
clean_answer = words_to_int[clean_answer]
|
| 53 |
+
else:
|
| 54 |
+
try:
|
| 55 |
+
clean_answer = int(round(float(clean_answer)))
|
| 56 |
+
except ValueError:
|
| 57 |
+
|
| 58 |
+
# or pick the LAST number in the string
|
| 59 |
+
match = re.search(r'\d+(?:\.\d+)?(?=(?![\s\S]*\d))', clean_answer)
|
| 60 |
+
if match:
|
| 61 |
+
clean_answer = int(round(float(match.group(0))))
|
| 62 |
+
else:
|
| 63 |
+
matched = False
|
| 64 |
+
# or pick the FIRST spelled out number in the string
|
| 65 |
+
for word, number in words_to_int.items():
|
| 66 |
+
if word in clean_answer:
|
| 67 |
+
clean_answer = number
|
| 68 |
+
matched = True
|
| 69 |
+
break
|
| 70 |
+
if not matched:
|
| 71 |
+
print(f"WARNING: Unable to convert answer '{answer}' to int.")
|
| 72 |
+
clean_answer = 0
|
| 73 |
+
return clean_answer
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def clean_multiple_choice_answer(answer, options):
|
| 77 |
+
if answer is None:
|
| 78 |
+
return ""
|
| 79 |
+
|
| 80 |
+
answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
|
| 81 |
+
if len(answer_matches) > 0:
|
| 82 |
+
answer = answer_matches[0]
|
| 83 |
+
|
| 84 |
+
clean_answer = answer.strip()
|
| 85 |
+
|
| 86 |
+
if answer.startswith("Answer:"):
|
| 87 |
+
clean_answer = clean_answer.replace("Answer:", "").strip()
|
| 88 |
+
if answer.startswith("The answer is "):
|
| 89 |
+
clean_answer = clean_answer.replace("The answer is ", "").strip()
|
| 90 |
+
if answer.startswith("The best answer is "):
|
| 91 |
+
clean_answer = clean_answer.replace("The best answer is ", "").strip()
|
| 92 |
+
if answer.endswith("."):
|
| 93 |
+
clean_answer = clean_answer[:-1].strip()
|
| 94 |
+
|
| 95 |
+
if len(clean_answer) > 1:
|
| 96 |
+
# If the answer is longer than one character, we assume it may contain the full label, e.g. "A. option text"
|
| 97 |
+
for option in options:
|
| 98 |
+
if option in clean_answer:
|
| 99 |
+
clean_answer = option[0]
|
| 100 |
+
break
|
| 101 |
+
|
| 102 |
+
if len(clean_answer) > 1:
|
| 103 |
+
# If the answer is still longer than one character, we assume it is a short label, e.g. "A" or "B"
|
| 104 |
+
for option in options:
|
| 105 |
+
if option[0] in clean_answer:
|
| 106 |
+
clean_answer = option[0]
|
| 107 |
+
break
|
| 108 |
+
|
| 109 |
+
return clean_answer
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def clean_ocr_answer(answer):
|
| 113 |
+
answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
|
| 114 |
+
if len(answer_matches) > 0:
|
| 115 |
+
answer = answer_matches[0]
|
| 116 |
+
|
| 117 |
+
clean_answer = answer.strip()
|
| 118 |
+
|
| 119 |
+
clean_answer = extract_text_from_quotes(clean_answer)
|
| 120 |
+
clean_answer = clean_text(clean_answer)
|
| 121 |
+
|
| 122 |
+
return clean_answer
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def validate_choice_answer(answer, benchmark_truth):
|
| 126 |
+
|
| 127 |
+
all_options = make_option_labels(benchmark_truth["options"])
|
| 128 |
+
|
| 129 |
+
# clean the reponse
|
| 130 |
+
clean_answer = clean_multiple_choice_answer(answer["response"], all_options)
|
| 131 |
+
|
| 132 |
+
# map the truth to option key
|
| 133 |
+
correct_option_index = benchmark_truth["options"].index(benchmark_truth["ground_truth"])
|
| 134 |
+
correct_option_enum = all_options[correct_option_index][0]
|
| 135 |
+
|
| 136 |
+
return (clean_answer == correct_option_enum)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
words_to_int = {
|
| 140 |
+
"zero": 0,
|
| 141 |
+
"one": 1,
|
| 142 |
+
"two": 2,
|
| 143 |
+
"three": 3,
|
| 144 |
+
"four": 4,
|
| 145 |
+
"five": 5,
|
| 146 |
+
"six": 6,
|
| 147 |
+
"seven": 7,
|
| 148 |
+
"eight": 8,
|
| 149 |
+
"nine": 9,
|
| 150 |
+
"ten": 10,
|
| 151 |
+
"eleven": 11,
|
| 152 |
+
"twelve": 12,
|
| 153 |
+
"thirteen": 13,
|
| 154 |
+
"fourteen": 14,
|
| 155 |
+
"fifteen": 15,
|
| 156 |
+
"sixteen": 16,
|
| 157 |
+
"seventeen": 17,
|
| 158 |
+
"eighteen": 18,
|
| 159 |
+
"nineteen": 19,
|
| 160 |
+
"twenty": 20,
|
| 161 |
+
"thirty": 30,
|
| 162 |
+
"forty": 40,
|
| 163 |
+
"fifty": 50,
|
| 164 |
+
"sixty": 60,
|
| 165 |
+
"seventy": 70,
|
| 166 |
+
"eighty": 80,
|
| 167 |
+
"ninety": 90,
|
| 168 |
+
"hundred": 100,
|
| 169 |
+
"thousand": 1000,
|
| 170 |
+
"million": 1000000,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
def clean_text(s):
|
| 174 |
+
|
| 175 |
+
replace_characters = {
|
| 176 |
+
"ä": "a",
|
| 177 |
+
"á": "a",
|
| 178 |
+
"à": "a",
|
| 179 |
+
"â": "a",
|
| 180 |
+
"ã": "a",
|
| 181 |
+
"å": "a",
|
| 182 |
+
"ā": "a",
|
| 183 |
+
"ö": "o",
|
| 184 |
+
"ó": "o",
|
| 185 |
+
"ò": "o",
|
| 186 |
+
"ô": "o",
|
| 187 |
+
"õ": "o",
|
| 188 |
+
"ō": "o",
|
| 189 |
+
"ü": "u",
|
| 190 |
+
"ú": "u",
|
| 191 |
+
"ù": "u",
|
| 192 |
+
"û": "u",
|
| 193 |
+
"ū": "u",
|
| 194 |
+
"é": "e",
|
| 195 |
+
"ĕ": "e",
|
| 196 |
+
"ė": "e",
|
| 197 |
+
"ę": "e",
|
| 198 |
+
"ě": "e",
|
| 199 |
+
"ç": "c",
|
| 200 |
+
"ć": "c",
|
| 201 |
+
"č": "c",
|
| 202 |
+
"ñ": "n",
|
| 203 |
+
"ń": "n",
|
| 204 |
+
"ń": "n",
|
| 205 |
+
"ł": "l",
|
| 206 |
+
"ś": "s",
|
| 207 |
+
"š": "s",
|
| 208 |
+
"ź": "z",
|
| 209 |
+
"ż": "z",
|
| 210 |
+
"ý": "y",
|
| 211 |
+
"ŷ": "y",
|
| 212 |
+
"ÿ": "y",
|
| 213 |
+
"œ": "oe",
|
| 214 |
+
"æ": "ae",
|
| 215 |
+
|
| 216 |
+
"v": "u"
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
s = s.strip().lower()
|
| 220 |
+
|
| 221 |
+
# replace special characters
|
| 222 |
+
for char, replacement in replace_characters.items():
|
| 223 |
+
s = s.replace(char, replacement)
|
| 224 |
+
|
| 225 |
+
s = s.replace("\t", "").replace("\n", "").replace(".", "").replace("&", "").replace(";", "")\
|
| 226 |
+
.replace(",", "").replace("-", "").replace("–", "").replace("’", "'").replace(":", "").replace("·", " ")\
|
| 227 |
+
.replace("'", "").replace("“", "").replace("”", "").replace('"', "").replace("•", " ")\
|
| 228 |
+
.replace(" ", " ")
|
| 229 |
+
return s
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def extract_text_from_quotes(s):
|
| 233 |
+
pattern = r"'([^']*?)\"|\"([^\"]*?)\"|`([^`]*?)`|“([^”]*?)”|‘([^’]*?)’"
|
| 234 |
+
matches = re.findall(pattern, s)
|
| 235 |
+
|
| 236 |
+
if matches:
|
| 237 |
+
# Extract and return the single letter
|
| 238 |
+
matched_group = [match for group in matches for match in group if match][0]
|
| 239 |
+
return matched_group
|
| 240 |
+
else:
|
| 241 |
+
# Return the original string if no match is found
|
| 242 |
+
return s
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def main(args):
|
| 246 |
+
try:
|
| 247 |
+
from pprint import pprint as print
|
| 248 |
+
except ImportError:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
question_index = build_question_index_from_file(args.benchmark_file)
|
| 252 |
+
|
| 253 |
+
for file in glob(args.results_file):
|
| 254 |
+
print(f"Judging {file}")
|
| 255 |
+
_, scores = judge_file(file, question_index)
|
| 256 |
+
print(scores)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def build_question_index_from_file(benchmark_file):
|
| 261 |
+
with open(benchmark_file, "r") as f:
|
| 262 |
+
json_data = json.load(f)
|
| 263 |
+
|
| 264 |
+
hardness_levels = json_data["splits"]
|
| 265 |
+
df_gt = pd.DataFrame.from_dict(json_data["benchmark"]).set_index("question_id")
|
| 266 |
+
|
| 267 |
+
for level, ids in hardness_levels.items():
|
| 268 |
+
df_gt.loc[ids, "difficulty"] = level
|
| 269 |
+
|
| 270 |
+
return build_question_index_from_json(df_gt.reset_index().to_dict(orient="records"))
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def build_question_index_from_json(benchmark_data):
|
| 274 |
+
question_index = {}
|
| 275 |
+
for question in benchmark_data:
|
| 276 |
+
question_id = question["question_id"]
|
| 277 |
+
question_index[question_id] = question
|
| 278 |
+
return question_index
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def judge_file(results_file, question_index):
|
| 282 |
+
with open(results_file, "r") as f:
|
| 283 |
+
results_data = json.load(f)
|
| 284 |
+
for answer in results_data:
|
| 285 |
+
answer["question_id"] = str(answer["question_id"]) # ensure question_id is string
|
| 286 |
+
|
| 287 |
+
return judge(results_data, question_index)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def judge(results_data, question_index):
|
| 291 |
+
answer_index = {}
|
| 292 |
+
for answer in results_data:
|
| 293 |
+
answer_index[answer["question_id"]] = answer
|
| 294 |
+
|
| 295 |
+
non_answered_questions = set(question_index.keys()) - set(answer_index.keys())
|
| 296 |
+
excessive_answers = set(answer_index.keys()) - set(question_index.keys())
|
| 297 |
+
|
| 298 |
+
if len(non_answered_questions) > 0:
|
| 299 |
+
print("WARNING: Some question IDs in benchmark data are not found in results file:")
|
| 300 |
+
print(non_answered_questions)
|
| 301 |
+
results_data = results_data + [{"question_id": qid, "response": ""} for qid in non_answered_questions]
|
| 302 |
+
else:
|
| 303 |
+
print("All question IDs in benchmark data are found in results file.")
|
| 304 |
+
|
| 305 |
+
if len(excessive_answers) > 0:
|
| 306 |
+
print("WARNING: Some question IDs in results file are not found in benchmark data:")
|
| 307 |
+
print(excessive_answers)
|
| 308 |
+
print("These questions will be ignored in the evaluation.")
|
| 309 |
+
results_data = [answer for answer in results_data if answer["question_id"] in question_index]
|
| 310 |
+
else:
|
| 311 |
+
print("All question IDs in results file are found in benchmark data.")
|
| 312 |
+
|
| 313 |
+
print()
|
| 314 |
+
|
| 315 |
+
accuracy_meters = defaultdict(AverageMeter)
|
| 316 |
+
|
| 317 |
+
# process counting data and compute accuracy and MAE
|
| 318 |
+
correct = 0
|
| 319 |
+
total = 0
|
| 320 |
+
mae = 0
|
| 321 |
+
mse = 0
|
| 322 |
+
|
| 323 |
+
for answer in results_data:
|
| 324 |
+
if answer["question_id"] not in question_index:
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
benchmark_truth = question_index[answer["question_id"]]
|
| 328 |
+
if benchmark_truth["question_type"] != "counting":
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
clean_answer = clean_counting_answer(answer["response"])
|
| 332 |
+
|
| 333 |
+
gt = benchmark_truth["ground_truth"]
|
| 334 |
+
difference = abs(clean_answer - gt)
|
| 335 |
+
mae += difference
|
| 336 |
+
mse += difference ** 2
|
| 337 |
+
|
| 338 |
+
is_correct = (clean_answer == gt)
|
| 339 |
+
correct_count = 1 if is_correct else 0
|
| 340 |
+
answer["judge/correct"] = is_correct
|
| 341 |
+
answer["judge/extracted_answer"] = clean_answer
|
| 342 |
+
|
| 343 |
+
correct += correct_count
|
| 344 |
+
total += 1
|
| 345 |
+
|
| 346 |
+
accuracy_meters[benchmark_truth["source_file"]].update(correct_count)
|
| 347 |
+
|
| 348 |
+
# process OCR data and compute accuracy and ESD
|
| 349 |
+
|
| 350 |
+
correct = 0
|
| 351 |
+
total = 0
|
| 352 |
+
|
| 353 |
+
for answer in results_data:
|
| 354 |
+
if answer["question_id"] not in question_index:
|
| 355 |
+
continue
|
| 356 |
+
|
| 357 |
+
benchmark_truth = question_index[answer["question_id"]]
|
| 358 |
+
if benchmark_truth["question_type"] != "ocr":
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
# clean the reponse
|
| 362 |
+
|
| 363 |
+
clean_answer = clean_ocr_answer(answer["response"])
|
| 364 |
+
clean_gt = clean_text(benchmark_truth["ground_truth"])
|
| 365 |
+
|
| 366 |
+
answer["judge/extracted_answer"] = clean_answer
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
is_correct = clean_answer == clean_gt
|
| 370 |
+
|
| 371 |
+
correct_count = 1 if is_correct else 0
|
| 372 |
+
|
| 373 |
+
answer["judge/correct"] = is_correct
|
| 374 |
+
|
| 375 |
+
correct += correct_count
|
| 376 |
+
total += 1
|
| 377 |
+
|
| 378 |
+
accuracy_meters[benchmark_truth["source_file"]].update(correct_count)
|
| 379 |
+
|
| 380 |
+
# process multiple choice data without binary options and compute accuracy
|
| 381 |
+
|
| 382 |
+
correct = 0
|
| 383 |
+
total = 0
|
| 384 |
+
|
| 385 |
+
for answer in results_data:
|
| 386 |
+
if answer["question_id"] not in question_index:
|
| 387 |
+
continue
|
| 388 |
+
|
| 389 |
+
benchmark_truth = question_index[answer["question_id"]]
|
| 390 |
+
if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 4:
|
| 391 |
+
continue
|
| 392 |
+
|
| 393 |
+
is_correct = validate_choice_answer(answer, benchmark_truth)
|
| 394 |
+
correct_count = 1 if is_correct else 0
|
| 395 |
+
answer["judge/correct"] = is_correct
|
| 396 |
+
|
| 397 |
+
correct += correct_count
|
| 398 |
+
total += 1
|
| 399 |
+
|
| 400 |
+
accuracy_meters[benchmark_truth["source_file"]].update(correct_count)
|
| 401 |
+
|
| 402 |
+
# process binary choice data and compute accuracy
|
| 403 |
+
correct = 0
|
| 404 |
+
total = 0
|
| 405 |
+
|
| 406 |
+
for answer in results_data:
|
| 407 |
+
if answer["question_id"] not in question_index:
|
| 408 |
+
continue
|
| 409 |
+
|
| 410 |
+
benchmark_truth = question_index[answer["question_id"]]
|
| 411 |
+
if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 2:
|
| 412 |
+
continue
|
| 413 |
+
|
| 414 |
+
is_correct = validate_choice_answer(answer, benchmark_truth)
|
| 415 |
+
correct_count = 1 if is_correct else 0
|
| 416 |
+
answer["judge/correct"] = is_correct
|
| 417 |
+
|
| 418 |
+
correct += correct_count
|
| 419 |
+
total += 1
|
| 420 |
+
|
| 421 |
+
# process binary choice data with correction and compute accuracy
|
| 422 |
+
correct = 0
|
| 423 |
+
total = 0
|
| 424 |
+
|
| 425 |
+
opposite_error_pairs = []
|
| 426 |
+
|
| 427 |
+
for answer in results_data:
|
| 428 |
+
if answer["question_id"] not in question_index:
|
| 429 |
+
continue
|
| 430 |
+
|
| 431 |
+
benchmark_truth = question_index[answer["question_id"]]
|
| 432 |
+
|
| 433 |
+
# skip if the question is not a binary choice or does not have an opposite
|
| 434 |
+
if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 2 or pd.isna(benchmark_truth["opposite_of"]):
|
| 435 |
+
continue
|
| 436 |
+
|
| 437 |
+
is_correct = validate_choice_answer(answer, benchmark_truth)
|
| 438 |
+
if benchmark_truth["opposite_of"] in answer_index:
|
| 439 |
+
is_opposite_correct = validate_choice_answer(answer_index[benchmark_truth["opposite_of"]], question_index[benchmark_truth["opposite_of"]])
|
| 440 |
+
answer_index[benchmark_truth["opposite_of"]]["judge/correct"] = is_opposite_correct
|
| 441 |
+
else:
|
| 442 |
+
is_opposite_correct = False
|
| 443 |
+
|
| 444 |
+
answer["judge/correct"] = is_correct
|
| 445 |
+
|
| 446 |
+
if is_correct and is_opposite_correct:
|
| 447 |
+
correct += 1
|
| 448 |
+
accuracy_meters[benchmark_truth["source_file"]].update(1)
|
| 449 |
+
else:
|
| 450 |
+
opposite_error_pairs.append((answer["question_id"], benchmark_truth["opposite_of"]))
|
| 451 |
+
accuracy_meters[benchmark_truth["source_file"]].update(0)
|
| 452 |
+
total += 1
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
df_preds = pd.DataFrame(results_data).set_index("question_id")
|
| 456 |
+
df_gt = pd.DataFrame.from_dict(question_index).T.set_index("question_id")
|
| 457 |
+
df = df_preds.join(df_gt)
|
| 458 |
+
|
| 459 |
+
scores = {
|
| 460 |
+
"is_complete": len(non_answered_questions) == 0,
|
| 461 |
+
"is_excessive": len(excessive_answers) > 0,
|
| 462 |
+
**dict([
|
| 463 |
+
("accuracy/" + k.replace(".csv", ""), (v.sum/v.count)) for k, v in accuracy_meters.items()
|
| 464 |
+
]),
|
| 465 |
+
"accuracy/easy": df.query("difficulty == 'easy'")["judge/correct"].mean(),
|
| 466 |
+
"accuracy/medium": df.query("difficulty == 'medium'")["judge/correct"].mean(),
|
| 467 |
+
"accuracy/hard": df.query("difficulty == 'hard'")["judge/correct"].mean(),
|
| 468 |
+
"accuracy/total": df["judge/correct"].mean(),
|
| 469 |
+
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
return results_data, scores
|
| 473 |
+
|