File size: 3,289 Bytes
62dca4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
GSM8K benchmark evaluation script.
"""

import ast
import re
from typing import Any, Dict, List, Optional, Tuple

from sglang.utils import download_and_cache_file, read_jsonl

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_few_shot_sgl_function

# Sentinel returned by get_answer_value when no integer can be extracted;
# load_data also rejects any reference label equal to it.
INVALID = -9_999_999


def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str:
    """Format a single example."""
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines: List[Dict], k: int) -> str:
    """Concatenate the first *k* records of *lines* into a few-shot prefix.

    Each record is rendered with its answer via ``get_one_example`` and is
    followed by a blank line (two newline characters). The result is empty
    when ``k <= 0``.

    Raises:
        IndexError: if ``k`` exceeds ``len(lines)``.
    """
    # str.join builds the prefix in a single pass instead of repeated `+=`.
    return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))


def get_answer_value(answer_str: str) -> int:
    """Extract the last integer that appears in *answer_str*.

    Commas are stripped first so thousands separators are ignored
    (``"1,234"`` -> ``1234``). Only runs of decimal digits are matched, so a
    minus sign or decimal point is not part of the number: ``"-3.5"`` yields
    ``5``.

    Returns:
        The integer value of the last digit run, or ``INVALID`` when the text
        contains no digits or the digit run is too long to convert.
    """
    numbers = re.findall(r"\d+", answer_str.replace(",", ""))
    if not numbers:
        return INVALID
    try:
        # int() rather than ast.literal_eval: literal_eval rejects leading
        # zeros ("007" is a SyntaxError in Python source), which would
        # misreport such answers as INVALID.
        return int(numbers[-1])
    except ValueError:
        # Raised when the digit run exceeds the interpreter's int
        # string-conversion limit (Python 3.11+), e.g. on degenerate output.
        return INVALID


@BENCHMARKS.register("gsm8k")
class GSM8KBenchmarker(Benchmarker):
    """GSM8K benchmark: 5-shot math word problems scored by exact integer match."""

    def __init__(self, num_samples: Optional[int] = None):
        # NOTE(review): the meaning of Benchmarker.__init__'s second
        # positional argument is defined in .base; GSM8K passes None.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Fetch the GSM8K test split and build prompts and integer labels.

        As a side effect, stores the 5-shot prompt prefix on
        ``self.few_shot_examples`` for use by ``create_sgl_function``.

        Returns:
            ``(questions, labels)``. ``questions`` is a list of
            ``{"question": <prompt ending in "Answer:">}`` dicts and ``labels``
            holds the matching integer answers. At most ``self.num_samples``
            records are used when that attribute is not None.

        Raises:
            ValueError: if a reference answer contains no parseable integer.
        """
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
        data_path = download_and_cache_file(url)
        lines = list(read_jsonl(data_path))

        # NOTE(review): the few-shot demonstrations are the first 5 test
        # records, and evaluation also starts at record 0. The first 5
        # evaluated questions therefore appear verbatim, with answers, in the
        # prompt. Left unchanged because fixing it changes which questions are
        # scored; confirm intent.
        self.few_shot_examples = get_few_shot_examples(lines, 5)

        # Bound the loop up front instead of breaking out of it. A zero or
        # negative num_samples yields no questions, as before.
        num_questions = len(lines)
        if self.num_samples is not None:
            num_questions = min(self.num_samples, num_questions)

        questions = []
        labels = []
        for i in range(num_questions):
            questions.append({"question": get_one_example(lines, i, False)})
            labels.append(get_answer_value(lines[i]["answer"]))

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        bad_indices = [i for i, label in enumerate(labels) if label == INVALID]
        if bad_indices:
            raise ValueError(
                f"Some labels are invalid (sample indices: {bad_indices[:10]})"
            )
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
        """Return the last integer in the model *output*, or ``INVALID``.

        *label* is accepted for interface compatibility and is not used.
        """
        return get_answer_value(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Return the fraction of ``labels`` matched exactly by ``predictions``.

        Pairs are compared positionally. If ``predictions`` is shorter, the
        missing entries count as incorrect, since the denominator is always
        ``len(labels)``. Extra predictions are ignored. Returns None when
        ``labels`` is empty or None.
        """
        if not labels:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels)

    def create_sgl_function(self):
        """Create the SGL few-shot function for GSM8K.

        Must be called after ``load_data``, which sets
        ``self.few_shot_examples``; otherwise an AttributeError is raised.
        """
        return create_few_shot_sgl_function(
            few_shot_examples=self.few_shot_examples,
            function_name="few_shot_gsm8k",
            answer_key="answer",
            stop=["Question", "Assistant:", "<|separator|>"],
        )