"""
GSM8K benchmark evaluation script.
"""
import ast
import re
from typing import Any, Dict, List, Optional, Tuple
from sglang.utils import download_and_cache_file, read_jsonl
from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_few_shot_sgl_function
# Sentinel returned by get_answer_value when no usable number is found.
# Parsed answers come from an unsigned digit regex and are therefore >= 0,
# so this negative value can never be confused with a real parsed answer.
INVALID = -9_999_999
def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str:
    """Render record ``i`` of ``lines`` as a question/answer prompt.

    Args:
        lines: Dataset records with "question" (and, if ``include_answer``,
            "answer") string fields.
        i: Index of the record to render.
        include_answer: When true, append the reference answer after a
            single space; otherwise the prompt ends at "Answer:" so the
            model completes it.

    Returns:
        The formatted "Question: ...\\nAnswer:" text.
    """
    record = lines[i]
    prompt = "Question: " + record["question"] + "\nAnswer:"
    if not include_answer:
        return prompt
    return prompt + " " + record["answer"]
def get_few_shot_examples(lines: List[Dict], k: int) -> str:
    """Build the few-shot prompt prefix from the first ``k`` records.

    Each record is rendered with its answer via ``get_one_example`` and is
    followed by a blank line, so the next question starts on its own.

    Args:
        lines: Dataset records with "question" and "answer" fields.
        k: Number of leading records to include; ``k <= 0`` yields "".

    Returns:
        The concatenated examples (empty string when ``k`` is not positive).

    Raises:
        IndexError: If ``k`` exceeds ``len(lines)``.
    """
    # str.join builds the result in one pass instead of repeated += copies.
    return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))
def get_answer_value(answer_str: str) -> int:
    """Extract the final integer from a reference answer or model output.

    Commas (thousands separators) are removed first, then the LAST run of
    digits in the text is taken as the answer. Only unsigned integers are
    recognized: a leading "-" or a decimal point is not part of a match, so
    "3.50" yields 50.

    Args:
        answer_str: Free-form text, e.g. a GSM8K answer ending in
            "#### 72" or a model completion.

    Returns:
        The parsed integer, or ``INVALID`` if the text contains no digits or
        the digit run cannot be converted to an int.
    """
    numbers = re.findall(r"\d+", answer_str.replace(",", ""))
    if not numbers:
        return INVALID
    try:
        # int() rather than ast.literal_eval: literal syntax rejects leading
        # zeros ("007") and non-ASCII digits that \d matches, which would
        # misreport otherwise valid answers as INVALID.
        return int(numbers[-1])
    except ValueError:
        # Digit runs beyond the interpreter's int string-conversion limit
        # (Python 3.11+) raise ValueError; treat them as unparsable.
        return INVALID
@BENCHMARKS.register("gsm8k")
class GSM8KBenchmarker(Benchmarker):
    """GSM8K benchmark: few-shot prompting scored by exact integer match.

    ``load_data`` must be called before ``create_sgl_function``, because it
    builds the few-shot prefix and stores it on ``self.few_shot_examples``.
    """

    # Number of leading test records rendered (with answers) as the
    # few-shot prompt prefix.
    NUM_FEW_SHOT = 5

    def __init__(self, num_samples: Optional[int] = None):
        """Create the benchmarker.

        Args:
            num_samples: Maximum number of questions ``load_data`` returns;
                ``None`` returns every record of the test split.
        """
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Download the GSM8K test split and build prompts and labels.

        Side effect: sets ``self.few_shot_examples`` for use by
        ``create_sgl_function``.

        Returns:
            ``(questions, labels)``: each question is
            ``{"question": <prompt ending in "Answer:">}`` and each label is
            the integer parsed from the record's reference answer. At most
            ``self.num_samples`` items are returned when it is set.

        Raises:
            ValueError: If any reference answer cannot be parsed to an int.
        """
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
        data_path = download_and_cache_file(url)
        lines = list(read_jsonl(data_path))

        # NOTE(review): the few-shot prefix comes from the first test records,
        # which are also evaluated below, so those questions appear in the
        # prompt together with their answers. Left unchanged so scores stay
        # comparable with earlier runs.
        few_shot_examples = get_few_shot_examples(lines, self.NUM_FEW_SHOT)

        limit = len(lines) if self.num_samples is None else self.num_samples
        selected = range(min(limit, len(lines)))
        questions = [
            {"question": get_one_example(lines, i, False)} for i in selected
        ]
        labels = [get_answer_value(lines[i]["answer"]) for i in selected]

        # Stored for create_sgl_function.
        self.few_shot_examples = few_shot_examples

        # Explicit check instead of `assert`, which is stripped under -O and
        # would let INVALID labels silently skew accuracy.
        num_invalid = sum(1 for label in labels if label == INVALID)
        if num_invalid:
            raise ValueError(
                f"Some labels are invalid ({num_invalid} of {len(labels)})"
            )
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
        """Parse the predicted integer from a model completion.

        ``label`` is not used by this implementation. Returns ``INVALID``
        when no number can be parsed from ``output``.
        """
        return get_answer_value(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Return the fraction of labels exactly matched by their prediction.

        Predictions and labels are paired by position. The denominator is
        ``len(labels)``, so a label with no corresponding prediction counts
        as incorrect.

        Returns:
            Accuracy in ``[0, 1]``, or ``None`` when ``labels`` is empty.
        """
        if not labels:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels)

    def create_sgl_function(self):
        """Create the few-shot SGL function for GSM8K.

        Reads ``self.few_shot_examples``, which is set by ``load_data``, and
        forwards it together with the answer key "answer" and the stop
        strings "Question", "Assistant:" and "<|separator|>" to
        ``create_few_shot_sgl_function``.
        """
        return create_few_shot_sgl_function(
            few_shot_examples=self.few_shot_examples,
            function_name="few_shot_gsm8k",
            answer_key="answer",
            stop=["Question", "Assistant:", "<|separator|>"],
        )
|