# JustinTX's picture
# Add files using upload-large-folder tool
# 14c9c2b verified
import time
import multiprocessing
import concurrent.futures
from typing import Callable
import pandas as pd
from utils import remove_boxed, last_boxed_only_string, is_equiv
def agent_evaluation(
    Agent,
    query_llm: Callable,
    year: int = 2024,
) -> tuple[float, float, int, int, pd.DataFrame]:
    """Evaluate *Agent* on the AIME problems for a single competition year.

    Loads ``AIME_Dataset_1983_2025.csv``, filters it to ``year``, solves each
    problem concurrently in a thread pool via :func:`process_example`, and
    prints per-example progress as futures complete.

    Args:
        Agent: Agent class; instantiated once as ``Agent(query_llm)`` and
            shared by all worker threads.
        query_llm: Callable the agent uses to query the LLM; if it exposes
            ``reset_calls`` / ``get_call_count`` those are used for
            per-example call accounting (see :func:`process_example`).
        year: Competition year to filter the dataset on.

    Returns:
        ``(final_accuracy, cost_total, total, total_llm_calls, df)`` where
        ``final_accuracy`` is a percentage, ``cost_total`` is the summed
        per-example cost, ``total`` counts completed futures (including ones
        that raised), and ``df`` has one row per successfully graded example.

    Raises:
        ValueError: if no examples were processed, or the final accuracy is 0
            (treated as a pipeline failure rather than a legitimate result).
    """
    math_test_set = pd.read_csv("AIME_Dataset_1983_2025.csv")
    math_test_set = math_test_set[math_test_set["Year"] == year]
    agent = Agent(query_llm)
    results = []
    # LLM calls are I/O-bound, so threads (not processes); capped at 30.
    max_workers = min(30, multiprocessing.cpu_count())
    print(f"Loaded AIME dataset with {len(math_test_set)} examples")
    print(f"Running parallel evaluation with {max_workers} workers")
    start_time = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its positional index for ids and logging.
        future_to_idx = {
            executor.submit(process_example, i, example, agent, query_llm): i
            for i, (_, example) in enumerate(math_test_set.iterrows())
        }
        total, correct_count, total_llm_calls, cost_total = 0, 0, 0, 0
        for future in concurrent.futures.as_completed(future_to_idx):
            idx = future_to_idx[future]
            # NOTE(review): errored examples still increment `total`, so they
            # count against the accuracy denominator — confirm intended.
            total += 1
            try:
                (
                    _idx,
                    problem,
                    response,
                    llm_answer,
                    true_answer,
                    correct,
                    cost,
                    num_llm_calls,
                ) = future.result()
                results.append(
                    {
                        "id": idx,
                        "problem": problem,
                        "response": response,
                        "llm_answer": llm_answer,
                        "true_answer": true_answer,
                        "correct": correct,
                        "cost": cost,
                        "num_llm_calls": num_llm_calls,
                    }
                )
            except Exception as e:
                # Failed example: log and skip cost/accuracy accounting below.
                print(f"Error processing example {idx}: {e}")
                continue
            cost_total += cost
            if correct:
                correct_count += 1
            total_llm_calls += num_llm_calls
            # Running accuracy over all futures completed so far.
            accuracy = (correct_count / total) * 100
            log_message = (
                f"Step: {total}, LLM answer: {llm_answer}, "
                f"True answer: {true_answer}, "
                f"Accuracy: {accuracy:.2f}%, "
                f"Cost: {cost_total:.4f}, "
                f"LLM calls: {total_llm_calls}, "
                f"Avg LLM calls: {total_llm_calls / total}"
            )
            print(log_message)
    if total > 0:
        final_accuracy = (correct_count / total) * 100
        # A 0% run almost certainly signals a pipeline bug, so fail loudly.
        if final_accuracy == 0:
            raise ValueError("Final accuracy is 0. This should not happen.")
        print(
            f"Complete, final accuracy: {final_accuracy:.2f}%, Cost: {cost_total:.2f}"
        )
        print(f"Time taken: {time.time() - start_time:.2f} seconds")
        time_per_example = (time.time() - start_time) / total
        print(f"Time per example: {time_per_example:.2f} seconds")
        df = pd.DataFrame(results)
    else:
        raise ValueError("No examples were processed.")
    return final_accuracy, cost_total, total, total_llm_calls, df
def evaluate_math_correctness(response: str, solution: str) -> tuple[str, str, bool]:
    """Evaluates the correctness of the LLM's response for MATH-500.

    Extracts the last ``\\boxed{...}`` answer from *response*, normalizes
    leading zeros (so e.g. "042" matches "42", and all-zero strings collapse
    to "0"), and compares it against *solution* via ``is_equiv``.

    Args:
        response: Full LLM response text containing a boxed final answer.
        solution: Ground-truth answer; may be a non-str (e.g. an int read
            from the CSV), so it is coerced with ``str`` before comparison.

    Returns:
        ``(llm_answer, true_answer, correct)`` — both answers as strings
        ("" when extraction fails) plus the equivalence verdict.
    """
    llm_answer_str = remove_boxed(last_boxed_only_string(response))
    if llm_answer_str is not None:
        # Normalize zero-padding; an all-zero answer must stay "0", not "".
        llm_answer_str = llm_answer_str.lstrip("0")
        if llm_answer_str == "":
            llm_answer_str = "0"
    # BUG FIX: the original computed `solution.strip()` (AttributeError for
    # non-str solutions) and then unconditionally clobbered it with
    # `str(solution)`, discarding the intended strip. Coerce first, then strip.
    true_answer_str = str(solution).strip()
    true_answer = "" if true_answer_str is None else true_answer_str
    llm_answer = "" if llm_answer_str is None else llm_answer_str
    correct = is_equiv(llm_answer, true_answer)
    return llm_answer, true_answer, correct
def evaluate_aime_correctness(
    response: str, solution: str
) -> tuple[str, str, bool, bool]:
    """Evaluates the correctness of the LLM's response for AIME.

    Pulls the final boxed answer out of *response*, strips any zero-padding
    (collapsing an all-zero answer to "0"), and checks equivalence against
    *solution*. Also reports a format flag: AIME answers are conventionally
    three digits, so any extracted answer whose length differs from 3 is
    marked as an output error.

    Returns:
        ``(llm_answer, true_answer, correct, out_error)``.
    """
    extracted = remove_boxed(last_boxed_only_string(response))
    if extracted is not None:
        unpadded = extracted.lstrip("0")
        # An answer of all zeros must remain "0" rather than become empty.
        extracted = unpadded if unpadded != "" else "0"
    gold = str(solution)
    true_answer = gold if gold is not None else ""
    llm_answer = extracted if extracted is not None else ""
    correct = is_equiv(llm_answer, true_answer)
    # AIME answers are 3 digits; flag anything of a different length.
    out_error = len(llm_answer) != 3
    return llm_answer, true_answer, correct, out_error
def process_example(idx, example, agent, query_llm):
    """Solve one dataset row with *agent* and grade the result.

    Resets the query function's call counter when available, runs the agent
    on the row's problem text, grades the answer with
    ``evaluate_math_correctness``, and packages everything into a flat tuple
    consumed by ``agent_evaluation``.

    Returns:
        ``(idx, problem, response, llm_answer, true_answer, correct, cost,
        num_llm_calls)``.
    """
    # Reset the per-example counter if the query_llm wrapper supports it.
    if hasattr(query_llm, "reset_calls"):
        query_llm.reset_calls()
    question = example["problem"].strip()
    gold = example["answer"]
    answer_text, run_cost = agent.forward(question)
    predicted, expected, is_correct = evaluate_math_correctness(answer_text, gold)
    calls_made = query_llm.get_call_count()
    return (
        idx,
        question,
        answer_text,
        predicted,
        expected,
        is_correct,
        run_cost,
        calls_made,
    )