# Uploaded by: csuhan
# Commit message: "Upload folder using huggingface_hub"
# Commit: b0c0df0 (verified)
import re
from lmms_eval.filters.extraction import ExtendedRegexFilter
from lmms_eval.filters.transformation import MapFilter
def ai2d_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the text prompt for one AI2D question.

    Args:
        doc: Dataset row; reads ``doc["question"]`` and the sequence of option
            strings ``doc["options"]``.
        lmms_eval_specific_kwargs: Prompt settings. ``"prompt_format"`` selects
            the layout (``"mcq"``, ``"qa"`` or ``"mcq_xcomposer"``).
            ``"pre_prompt"`` / ``"post_prompt"`` are placed before / after the
            formatted question and default to ``""`` when absent. ``None`` is
            treated as an empty mapping.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``prompt_format`` is missing or not a supported layout.
    """
    # None used to crash with a TypeError on subscripting; treat it as "no settings".
    kwargs = lmms_eval_specific_kwargs or {}
    question, choices = doc["question"], doc["options"]
    pre_prompt = kwargs.get("pre_prompt", "")
    post_prompt = kwargs.get("post_prompt", "")
    prompt_format = kwargs.get("prompt_format")

    if prompt_format == "qa":
        options = "\n".join(choices)
        # NOTE(review): there is no separator between the question and the first
        # option (e.g. "Q?sun\nmoon"). Kept verbatim so existing evaluation
        # prompts stay reproducible; confirm whether a "\n" was intended.
        return f"{pre_prompt}{question}{options}{post_prompt}"

    if prompt_format in ("mcq", "mcq_xcomposer"):
        # Label options "A. ...", "B. ...", ... in their original order.
        lettered = [f"{chr(ord('A') + i)}. {choice}" for i, choice in enumerate(choices)]
        if prompt_format == "mcq":
            choices_str = "\n".join(lettered)
            return f"{pre_prompt}{question}\n{choices_str}{post_prompt}"
        # XComposer layout: an empty "Context" line, then all options on one line.
        choices_str = " ".join(lettered)
        return f"{pre_prompt}{question}\nContext: N/A\n{choices_str}{post_prompt}"

    raise ValueError(f"Unknown prompt format: {prompt_format}")
def ai2d_doc_to_visual(doc):
    """Return the document's image, converted to RGB, wrapped in a one-element list.

    ``doc["image"]`` must provide a ``convert(mode)`` method (presumably a PIL
    image -- confirm against the dataset loader).
    """
    rgb_image = doc["image"].convert("RGB")
    return [rgb_image]
def ai2d_doc_to_target(doc, model_specific_target_kwargs):
    """Return the gold answer for one AI2D document.

    Args:
        doc: Dataset row; ``doc["answer"]`` is an ``int()``-convertible index
            into the sequence ``doc["options"]``.
        model_specific_target_kwargs: Target format. ``"mcq"`` and
            ``"mcq_xcomposer"`` (whose prompts label options with letters)
            return the option letter, e.g. ``"C"``. ``"qa"`` returns the option
            text itself.

    Returns:
        The target string in the requested format.

    Raises:
        ValueError: For any other target format, instead of silently
            returning ``None`` and corrupting scoring.
    """
    if model_specific_target_kwargs in ("mcq", "mcq_xcomposer"):
        # Same A, B, C, ... labelling the prompt uses; list indexing keeps the
        # original IndexError behaviour for an out-of-range answer.
        letters = [chr(ord("A") + i) for i in range(len(doc["options"]))]
        return letters[int(doc["answer"])]
    if model_specific_target_kwargs == "qa":
        return doc["options"][int(doc["answer"])]
    raise ValueError(f"Unknown target format: {model_specific_target_kwargs}")
class MultiChoiceRegexFilter(ExtendedRegexFilter):
    """Reduce each document's model responses to a single answer string.

    For every document only the first model response is kept. If that response
    starts (after optional whitespace) with an uppercase letter followed by a
    period -- e.g. ``"B. a cat"`` -- the letter (``"B"``) is returned;
    otherwise the response is returned unchanged.
    """

    # Leading option letter followed by a period, e.g. "  C. something".
    # Compiled once here rather than once per document inside apply().
    _OPTION_LETTER_RE = re.compile(r"^\s*([A-Z])\.")

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``ExtendedRegexFilter``.

        NOTE(review): options previously documented here (``regex_pattern``,
        ``group_select``, ``ignore_case``, ``ignore_punctuation``,
        ``regexes_to_ignore``) are only passed through to the base class. The
        ``apply`` override in this class does not read any of them, so they
        have no effect on its output.
        """
        super().__init__(*args, **kwargs)

    def apply(self, resps, docs):
        """Extract an option letter from the first response of each document.

        Args:
            resps: One entry per document, each an iterable of model response
                strings for that document.
            docs: The matching documents. They are only paired with ``resps``,
                so iteration stops at the shorter of the two.

        Returns:
            A flat list with one string per processed document: the leading
            option letter of its first response, or that response verbatim
            when it has no ``"<letter>."`` prefix.

        Raises:
            ValueError: If a document has no responses.
        """
        filtered_resps = []
        for doc_resps, _doc in zip(resps, docs):
            responses = list(doc_resps)
            if not responses:
                raise ValueError("MultiChoiceRegexFilter.apply received a document with no responses")
            # Only the first response contributes to the output, so only it is parsed.
            first_resp = responses[0]
            match = self._OPTION_LETTER_RE.match(first_resp)
            filtered_resps.append(match.group(1) if match else first_resp)
        return filtered_resps