File size: 5,374 Bytes
14c9c2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import time
import multiprocessing
import concurrent.futures
from typing import Callable
import pandas as pd
from utils import remove_boxed, last_boxed_only_string, is_equiv


def agent_evaluation(
    Agent,
    query_llm: Callable,
    year: int = 2024,
) -> tuple[float, float, int, int, pd.DataFrame]:
    """Evaluate an agent on one year of AIME problems using a thread pool.

    Rows whose ``Year`` column equals ``year`` are read from
    ``AIME_Dataset_1983_2025.csv`` (relative to the current working
    directory). Each row is submitted to ``process_example`` and a running
    summary is printed as every example completes.

    Args:
        Agent: Factory invoked as ``Agent(query_llm)``; the single resulting
            agent is shared by all worker threads.
        query_llm: LLM query callable handed to the agent factory and to
            ``process_example``.
        year: Value of the ``Year`` column selecting the rows to evaluate.

    Returns:
        ``(final_accuracy, cost_total, total, total_llm_calls, df)``.
        ``final_accuracy`` is a percentage computed over ``total``, which
        counts every completed future, including those that raised.
        ``df`` holds one row per successfully processed example, in
        completion order.

    Raises:
        ValueError: If no example was processed, or if the final accuracy
            is exactly 0.
    """
    dataset = pd.read_csv("AIME_Dataset_1983_2025.csv")
    dataset = dataset[dataset["Year"] == year]
    agent = Agent(query_llm)

    rows = []
    max_workers = min(30, multiprocessing.cpu_count())
    print(f"Loaded AIME dataset with {len(dataset)} examples")
    print(f"Running parallel evaluation with {max_workers} workers")
    start_time = time.time()

    # Running tallies, updated as futures complete.
    total = 0
    correct_count = 0
    total_llm_calls = 0
    cost_total = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to the positional index of its example.
        pending = {}
        for position, (_, example) in enumerate(dataset.iterrows()):
            job = pool.submit(process_example, position, example, agent, query_llm)
            pending[job] = position

        for job in concurrent.futures.as_completed(pending):
            position = pending[job]
            # Failed examples still count towards the denominator.
            total += 1
            try:
                (
                    _,
                    problem,
                    response,
                    llm_answer,
                    true_answer,
                    correct,
                    cost,
                    num_llm_calls,
                ) = job.result()
                rows.append(
                    dict(
                        id=position,
                        problem=problem,
                        response=response,
                        llm_answer=llm_answer,
                        true_answer=true_answer,
                        correct=correct,
                        cost=cost,
                        num_llm_calls=num_llm_calls,
                    )
                )
            except Exception as exc:
                print(f"Error processing example {position}: {exc}")
                continue

            cost_total += cost
            correct_count += 1 if correct else 0
            total_llm_calls += num_llm_calls
            accuracy = (correct_count / total) * 100
            print(
                f"Step: {total}, LLM answer: {llm_answer}, "
                f"True answer: {true_answer}, "
                f"Accuracy: {accuracy:.2f}%, "
                f"Cost: {cost_total:.4f}, "
                f"LLM calls: {total_llm_calls}, "
                f"Avg LLM calls: {total_llm_calls / total}"
            )

    if not total > 0:
        raise ValueError("No examples were processed.")

    final_accuracy = (correct_count / total) * 100
    if final_accuracy == 0:
        raise ValueError("Final accuracy is 0. This should not happen.")
    print(
        f"Complete, final accuracy: {final_accuracy:.2f}%, Cost: {cost_total:.2f}"
    )
    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    time_per_example = (time.time() - start_time) / total
    print(f"Time per example: {time_per_example:.2f} seconds")

    df = pd.DataFrame(rows)
    return final_accuracy, cost_total, total, total_llm_calls, df


def evaluate_math_correctness(response: str, solution: str) -> tuple[str, str, bool]:
    """Compare the boxed answer found in ``response`` with ``solution``.

    The candidate answer is obtained by passing ``response`` through
    ``last_boxed_only_string`` and then ``remove_boxed``. Leading zeros are
    stripped from it (an all-zero answer becomes ``"0"``). If no answer is
    extracted, the empty string is used.

    Args:
        response: Raw text produced by the model.
        solution: Reference answer. It is used as-is rather than un-boxed;
            non-string values (e.g. integers read from a CSV) are accepted
            and converted with ``str``.

    Returns:
        ``(llm_answer, true_answer, correct)`` where ``correct`` is the
        result of ``is_equiv(llm_answer, true_answer)``.
    """
    # Coerce before stripping: calling ``.strip()`` directly on a numeric
    # answer raises AttributeError, and the stripped value must be the one
    # actually compared (it was previously computed and then discarded).
    true_answer = str(solution).strip()

    extracted = remove_boxed(last_boxed_only_string(response))
    if extracted is None:
        llm_answer = ""
    else:
        # "007" -> "7"; "000" -> "" -> "0".
        llm_answer = extracted.lstrip("0") or "0"

    correct = is_equiv(llm_answer, true_answer)
    return llm_answer, true_answer, correct


def evaluate_aime_correctness(
    response: str, solution: str
) -> tuple[str, str, bool, bool]:
    """Score a model response against an AIME reference answer.

    The candidate answer is obtained by passing ``response`` through
    ``last_boxed_only_string`` and then ``remove_boxed``; leading zeros are
    removed (an all-zero answer becomes ``"0"``) and a missing answer is
    represented by the empty string. The reference is ``str(solution)``.

    Returns:
        ``(llm_answer, true_answer, correct, out_error)``. ``correct`` is
        the result of ``is_equiv``. ``out_error`` is True whenever the
        normalised candidate is not exactly three characters long; because
        the check runs after leading zeros are stripped, an answer such as
        ``"025"`` is flagged as well.
    """
    extracted = remove_boxed(last_boxed_only_string(response))
    if extracted is None:
        llm_answer = ""
    else:
        # "007" -> "7"; "000" -> "" -> "0".
        llm_answer = extracted.lstrip("0") or "0"

    # ``str`` never yields None, so the reference needs no fallback.
    true_answer = str(solution)

    is_correct = is_equiv(llm_answer, true_answer)
    wrong_length = len(llm_answer) != 3
    return llm_answer, true_answer, is_correct, wrong_length


def process_example(idx, example, agent, query_llm):
    """Run the agent on a single dataset row and score its answer.

    Args:
        idx: Identifier of the example; returned unchanged.
        example: Mapping-like row providing ``"problem"`` (a string) and
            ``"answer"``.
        agent: Object whose ``forward(problem)`` returns ``(response, cost)``.
        query_llm: LLM query callable. If it exposes ``reset_calls`` and
            ``get_call_count`` they are used to report the number of LLM
            calls; a plain callable without them is also accepted.

    Returns:
        ``(idx, problem, response, llm_answer, true_answer, correct, cost,
        num_llm_calls)``. ``num_llm_calls`` is 0 when ``query_llm`` has no
        ``get_call_count`` method.
    """
    # Reset call count for each example if using call-limited query_llm.
    # NOTE(review): agent_evaluation passes the same query_llm to every
    # worker thread, so a shared counter may mix calls from concurrent
    # examples -- confirm whether the counter is per-thread.
    if hasattr(query_llm, "reset_calls"):
        query_llm.reset_calls()

    problem = example["problem"].strip()
    solution = example["answer"]
    response, cost = agent.forward(problem)
    llm_answer, true_answer, correct = evaluate_math_correctness(response, solution)

    # Mirror the optional handling of ``reset_calls``: a query_llm without
    # call tracking should not make the whole example fail.
    get_call_count = getattr(query_llm, "get_call_count", None)
    num_llm_calls = get_call_count() if get_call_count is not None else 0

    return (
        idx,
        problem,
        response,
        llm_answer,
        true_answer,
        correct,
        cost,
        num_llm_calls,
    )