"""
MMStar benchmark evaluation script.
"""
import os
import re
import shutil
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_image_sgl_function
def extract_mmstar_answer(
    output: str, options: Optional[List[str]] = None
) -> Optional[str]:
    """Extract the chosen option letter from an MMStar model output.

    MMStar questions are multiple choice (A, B, C, D, ...). The output is
    upper-cased and searched in two passes:

    1. An explicit answer marker (``ANSWER:``, ``答案:``, ``选择:``; ASCII or
       full-width colon), optionally followed by an opening ``(`` or ``[``.
    2. Any standalone single letter (this also covers ``(B)`` and ``[B]``).

    In both passes the first letter that lies in the valid range wins, so an
    out-of-range letter such as the pronoun "I" no longer hides a later
    valid choice.

    Args:
        output: Raw text produced by the model. ``None`` or an empty string
            yields ``None``.
        options: Option texts for the question. When non-empty, the valid
            range is ``A`` .. ``chr(64 + len(options))``; otherwise ``A``-``D``
            is assumed.

    Returns:
        The upper-case option letter, or ``None`` if no valid letter is found.
    """
    if not output:
        return None

    text = output.strip().upper()

    # Highest acceptable letter: 'A' + (len(options) - 1), defaulting to 'D'.
    max_letter = chr(64 + len(options)) if options else "D"

    # The text is upper-cased above, so the English marker must be upper-case
    # too. The lookahead keeps "ANSWER: BECAUSE" from being read as "B".
    marker_pattern = r"(?:ANSWER|答案|选择)\s*[:\uff1a]\s*[\(\[]?([A-Z])(?![A-Z])"
    # Fallback: any single letter that stands alone as a word.
    standalone_pattern = r"\b([A-Z])\b"

    for pattern in (marker_pattern, standalone_pattern):
        for match in re.finditer(pattern, text):
            letter = match.group(1)
            if "A" <= letter <= max_letter:
                return letter
    return None
@BENCHMARKS.register("mmstar")
class MMStarBenchmarker(Benchmarker):
    """MMStar benchmark implementation.

    ``load_data`` reads the ``val`` split of ``Lin-Chen/MMStar``, writes each
    image into a local cache directory as JPEG and returns the questions with
    their ground-truth option letters. ``run`` removes the cache directory
    once the benchmark finishes, whether or not it succeeded.
    """

    # Dataset fields probed, in this order, for the ground-truth answer.
    _ANSWER_KEYS = ("answer", "correct_answer", "ground_truth")

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize the benchmark; the cache directory is set by ``load_data``.

        Args:
            num_samples: Maximum number of dataset rows to load, or ``None``
                to load every row.
        """
        super().__init__(num_samples, None)
        self.cache_dir = None
        self.options_list = []  # Parsed options for each loaded question

    @staticmethod
    def _split_question(question_full: str) -> Tuple[str, List[str]]:
        """Split a raw question into its text and its parsed option texts."""
        if "Options:" not in question_full:
            return question_full.strip(), []

        question_text, options_text = question_full.split("Options:", 1)
        # NOTE(review): only newline-separated "A. text" lines are recognized;
        # confirm this matches the option layout the dataset actually uses.
        options = []
        for line in options_text.strip().split("\n"):
            line = line.strip()
            if line and re.match(r"^[A-Z]\.", line):
                options.append(re.sub(r"^[A-Z]\.\s*", "", line).strip())
        return question_text.strip(), options

    @staticmethod
    def _validate_label(answer: Optional[str], options: List[str]) -> Optional[str]:
        """Return ``answer`` if it is a usable option letter, else ``None``."""
        if not (answer and len(answer) == 1 and "A" <= answer <= "Z"):
            return None
        if options:
            # With known options the letter must not exceed the last one.
            max_option = chr(64 + len(options))
            return answer if answer <= max_option else None
        return answer

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess the MMStar dataset.

        Returns:
            A ``(questions, labels)`` pair of equal length. Each question is a
            dict with ``image_path`` (the JPEG written to the cache) and
            ``question`` (the text before any ``Options:`` section). Each
            label is the upper-case answer letter, or ``None`` when the row
            has no valid single-letter answer.
        """
        self.cache_dir = os.path.join(".cache", "mmstar_specforge")
        image_dir = os.path.join(self.cache_dir, "images")
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(image_dir, exist_ok=True)
        print(f"Created temporary image directory: {self.cache_dir}")

        dataset = load_dataset("Lin-Chen/MMStar")["val"]

        questions = []
        labels = []
        self.options_list = []
        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"])
            # The dataset-provided relative path may point into a
            # sub-directory other than "images"; make sure it exists.
            parent_dir = os.path.dirname(image_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            q["image"].convert("RGB").save(image_path, "JPEG")

            question_text, options = self._split_question(q["question"])
            self.options_list.append(options)
            questions.append(
                {
                    "image_path": image_path,
                    "question": question_text,
                }
            )

            # Take the ground truth from the first answer field present.
            answer = None
            for key in self._ANSWER_KEYS:
                if key in q:
                    answer = str(q[key]).strip().upper()
                    break
            labels.append(self._validate_label(answer, options))

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the answer letter from a model output.

        The question index is not available here, so the per-question options
        are not passed and the default ``A``-``D`` range applies. ``label`` is
        not used.
        """
        return extract_mmstar_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for MMStar by comparing answer choices.

        Pairs whose label is ``None`` are skipped; a ``None`` prediction for a
        labelled question counts as wrong. Comparison ignores case and
        surrounding whitespace.

        Returns:
            The fraction of labelled questions answered correctly, or ``None``
            when ``labels`` is empty or contains only ``None``.
        """
        if not labels:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is None:
                continue
            valid_count += 1
            if pred is None:
                continue
            if str(pred).strip().upper() == str(label).strip().upper():
                correct += 1

        # valid_count can still be 0 if predictions is shorter than labels.
        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create the SGL function for MMStar (image-based Q&A)."""
        return create_image_sgl_function(
            function_name="get_mmstar_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def run(self, *args, **kwargs):
        """Run the benchmark and always clean up the cache directory."""
        try:
            return super().run(*args, **kwargs)
        finally:
            # Remove the images written by load_data, even on failure.
            if self.cache_dir and os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
                print(f"Deleted temporary directory: {self.cache_dir}")
|