|
|
|
|
|
"""
|
|
|
Script to calculate accuracy of a trained GRPO model on arithmetic countdown problems.
|
|
|
|
|
|
This script loads a CSV file with problem data, performs inference using the trained model,
|
|
|
and calculates the accuracy by comparing predicted answers with correct answers.
|
|
|
"""
|
|
|
import sys, os
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
project_root = os.path.abspath(os.path.join(current_dir, "../.."))
|
|
|
sys.path.append(project_root)
|
|
|
|
|
|
import argparse
|
|
|
import logging
|
|
|
import re
|
|
|
import sys
|
|
|
from pathlib import Path
|
|
|
|
|
|
import pandas as pd
|
|
|
import torch
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent.parent))
|
|
|
|
|
|
from src.utils.inference import GRPOModelInference
|
|
|
|
|
|
|
|
|
logging.basicConfig(
|
|
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
|
|
)
|
|
|
logger = logging.getLogger("calculate_accuracy")
|
|
|
|
|
|
|
|
|
def load_csv_data(csv_path: str) -> pd.DataFrame:
    """
    Load the countdown-problem CSV and validate its schema.

    Expected columns: id, problem_description, correct_answer, num1, num2, num3, num4

    Args:
        csv_path: Path to the CSV file

    Returns:
        DataFrame with the loaded data

    Raises:
        ValueError: If any of the required columns is absent.
    """
    df = pd.read_csv(csv_path)

    # Fail fast with a clear message if the file does not match the
    # expected countdown-problem schema.
    expected = (
        "id",
        "problem_description",
        "correct_answer",
        "num1",
        "num2",
        "num3",
        "num4",
    )
    missing_columns = [name for name in expected if name not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    logger.info(f"Loaded {len(df)} problems from {csv_path}")
    return df
|
|
|
|
|
|
|
|
|
def safe_eval_expression(expression: str) -> tuple[float | None, bool]:
|
|
|
"""
|
|
|
Safely evaluate an arithmetic expression.
|
|
|
|
|
|
Args:
|
|
|
expression: The arithmetic expression to evaluate
|
|
|
|
|
|
Returns:
|
|
|
Tuple of (result, is_valid)
|
|
|
"""
|
|
|
if not expression or not expression.strip():
|
|
|
return None, False
|
|
|
|
|
|
|
|
|
normalized = expression.replace("x", "*").replace("X", "*")
|
|
|
|
|
|
|
|
|
allowed_chars = set("0123456789+-*/.() ")
|
|
|
if not all(c in allowed_chars for c in normalized):
|
|
|
return None, False
|
|
|
|
|
|
try:
|
|
|
result = eval(normalized)
|
|
|
return result, True
|
|
|
except (SyntaxError, ValueError, ZeroDivisionError, NameError):
|
|
|
return None, False
|
|
|
|
|
|
|
|
|
def check_numbers_usage(expression: str, required_numbers: list[int]) -> bool:
    """
    Check if the expression uses exactly the required numbers.

    Each required number must appear exactly once (as a standalone integer
    token) — no extras, no omissions.

    Args:
        expression: The arithmetic expression to check
        required_numbers: List of numbers that should be used exactly once each

    Returns:
        True if expression uses all required numbers exactly once, False otherwise
    """
    # Blank / empty answers trivially fail.
    if not expression or not expression.strip():
        return False

    # Extract every standalone integer token from the expression.
    tokens = re.findall(r"\b\d+\b", expression)

    try:
        found = [int(tok) for tok in tokens]
    except ValueError:
        return False

    # Multiset comparison: order of use does not matter, multiplicity does.
    return sorted(found) == sorted(required_numbers)
|
|
|
|
|
|
|
|
|
def evaluate_prediction(
    predicted_answer: str, correct_answer: int, nums: list[int]
) -> dict:
    """
    Evaluate a single prediction against the correct answer.

    A prediction counts as correct only when it parses as valid arithmetic,
    uses every required number exactly once, and evaluates to the target
    value (within a small float tolerance).

    Args:
        predicted_answer: The model's predicted arithmetic expression
        correct_answer: The correct integer result
        nums: List of four numbers used in the problem

    Returns:
        Dictionary with evaluation results
    """
    # Run both checks up front, then assemble the result record.
    predicted_result, is_valid = safe_eval_expression(predicted_answer)
    uses_all = check_numbers_usage(predicted_answer, nums)

    outcome = {
        "predicted_answer": predicted_answer,
        "correct_answer": correct_answer,
        "is_correct": False,
        "is_valid_format": is_valid,
        "uses_all_numbers": uses_all,
        "predicted_result": predicted_result,
        "correct_result": correct_answer,
    }

    logger.info(
        f"Answered: {predicted_answer} - Predicted result: {predicted_result} - Correct result: {correct_answer} - Uses all numbers: {uses_all}"
    )

    # Tolerance comparison guards against float round-off from division.
    if is_valid and predicted_result is not None and uses_all:
        outcome["is_correct"] = abs(predicted_result - correct_answer) < 1e-6

    return outcome
|
|
|
|
|
|
|
|
|
def calculate_accuracy(
    csv_path: str,
    sft_model_path: str | None,
    grpo_model_path: str | None,
    base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
    device: str = "auto",
    dtype: torch.dtype = torch.float16,
    max_new_tokens: int = 4096,
    temperature: float = 0.7,
    max_samples: int | None = None,
    output_path: str | None = None,
) -> dict:
    """
    Calculate accuracy of the model on the given dataset.

    Runs inference over every row of the CSV, scores each extracted answer
    with evaluate_prediction, and aggregates accuracy / validity metrics.

    Args:
        csv_path: Path to the CSV file with test data
        sft_model_path: Path to the SFT model (None to skip the SFT adapter)
        grpo_model_path: Path to the GRPO model (None to skip the GRPO adapter)
        base_model_id: Base model identifier
        device: Device to run inference on
        dtype: Data type for the model
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature
            NOTE(review): this default (0.7) differs from the CLI default
            (1.0) in main() — confirm which is intended.
        max_samples: Maximum number of samples to evaluate (None for all)
        output_path: Path to save detailed results (optional)

    Returns:
        Dictionary with accuracy metrics
    """
    # Load and validate the evaluation data.
    df = load_csv_data(csv_path)

    # Optionally cap the number of evaluated rows (takes the first N).
    if max_samples is not None:
        df = df.head(max_samples)
        logger.info(f"Limiting evaluation to {max_samples} samples")

    # Build the inference wrapper; None paths skip the corresponding adapter.
    logger.info("Loading model...")
    model_inference = GRPOModelInference(
        sft_model_path=sft_model_path,
        grpo_model_path=grpo_model_path,
        base_model_id=base_model_id,
        device=device,
        dtype=dtype,
    )

    # Per-row evaluation records plus running counters for the progress bar.
    results = []
    correct_predictions = 0
    valid_format_predictions = 0
    uses_all_numbers_predictions = 0

    logger.info("Starting evaluation...")
    pbar = tqdm(df.iterrows(), total=len(df), desc="Evaluating")

    # idx counts processed rows (0-based); the DataFrame index is ignored.
    for idx, (_, row) in enumerate(pbar):
        # Generate a solution; extracted_answer is the parsed expression.
        response, extracted_answer, _ = model_inference.solve_problem(
            problem_description=row["problem_description"],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Score the extracted expression against the ground truth.
        nums = [row["num1"], row["num2"], row["num3"], row["num4"]]
        evaluation = evaluate_prediction(
            predicted_answer=extracted_answer,
            correct_answer=row["correct_answer"],
            nums=nums,
        )

        # Attach problem metadata so the saved CSV is self-describing.
        evaluation.update(
            {
                "id": row["id"],
                "problem_description": row["problem_description"],
                "full_response": response,
                "nums": nums,
            }
        )

        results.append(evaluation)

        # Update running counters.
        if evaluation["is_correct"]:
            correct_predictions += 1
        if evaluation["is_valid_format"]:
            valid_format_predictions += 1
        if evaluation["uses_all_numbers"]:
            uses_all_numbers_predictions += 1

        # Show live metrics in the progress bar (idx + 1 is always > 0 here;
        # the guard is defensive only).
        current_accuracy = correct_predictions / (idx + 1) if (idx + 1) > 0 else 0
        current_valid_rate = (
            valid_format_predictions / (idx + 1) if (idx + 1) > 0 else 0
        )
        pbar.set_postfix(
            {
                "Acc": f"{current_accuracy:.3f}",
                "Valid": f"{current_valid_rate:.3f}",
                "Correct": f"{correct_predictions}/{idx + 1}",
            }
        )

    # Final aggregate metrics (guards cover the empty-dataset case).
    total_samples = len(results)
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
    valid_format_rate = (
        valid_format_predictions / total_samples if total_samples > 0 else 0
    )
    uses_all_numbers_rate = (
        uses_all_numbers_predictions / total_samples if total_samples > 0 else 0
    )

    metrics = {
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
        "valid_format_predictions": valid_format_predictions,
        "uses_all_numbers_predictions": uses_all_numbers_predictions,
        "accuracy": accuracy,
        "valid_format_rate": valid_format_rate,
        "uses_all_numbers_rate": uses_all_numbers_rate,
    }

    logger.info("Evaluation completed!")
    logger.info(f"Total samples: {total_samples}")
    logger.info(f"Correct predictions: {correct_predictions}")
    logger.info(f"Valid format predictions: {valid_format_predictions}")
    logger.info(f"Uses all numbers predictions: {uses_all_numbers_predictions}")
    logger.info(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
    logger.info(
        f"Valid format rate: {valid_format_rate:.4f} ({valid_format_rate * 100:.2f}%)"
    )
    logger.info(
        f"Uses all numbers rate: {uses_all_numbers_rate:.4f} ({uses_all_numbers_rate * 100:.2f}%)"
    )

    # Optionally persist the per-problem records for later inspection.
    if output_path:
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_path, index=False)
        logger.info(f"Detailed results saved to {output_path}")

    return metrics
|
|
|
|
|
|
|
|
|
def main():
    """Parse command-line arguments, run the evaluation, and print a summary."""
    parser = argparse.ArgumentParser(
        description="Calculate accuracy of GRPO model on arithmetic countdown problems"
    )

    parser.add_argument(
        "--csv_path",
        type=str,
        # FIX: this flag previously declared both required=True and a default;
        # argparse ignores the default when required=True, so the documented
        # default could never take effect. The flag is now optional and the
        # default applies (existing invocations that pass --csv_path still work).
        default="data/grpo/test.csv",
        help="Path to CSV file with test data",
    )
    parser.add_argument(
        "--sft_model_path",
        type=str,
        default="models/sft/",
        help="Path to SFT model directory",
    )
    parser.add_argument(
        "--grpo_model_path",
        type=str,
        default="models/grpo/",
        help="Path to GRPO model directory",
    )
    parser.add_argument(
        "--base_model_id",
        type=str,
        default="Qwen/Qwen2.5-Math-1.5B",
        help="Base model identifier",
    )
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=4096, help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to save detailed results CSV",
    )
    parser.add_argument(
        "--no_sft",
        action="store_true",
        help="Skip loading the SFT model (use only base model)",
    )
    parser.add_argument(
        "--no_grpo",
        action="store_true",
        help="Skip loading the GRPO model (use only SFT model)",
    )

    args = parser.parse_args()

    # Inference dtype is fixed to fp16; promote to a CLI flag if more
    # precision options are ever needed.
    dtype = torch.float16

    # --no_sft / --no_grpo translate to None paths, which skip the adapter.
    metrics = calculate_accuracy(
        csv_path=args.csv_path,
        sft_model_path=args.sft_model_path if not args.no_sft else None,
        grpo_model_path=args.grpo_model_path if not args.no_grpo else None,
        base_model_id=args.base_model_id,
        device=args.device,
        dtype=dtype,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        max_samples=args.max_samples,
        output_path=args.output_path,
    )

    # Human-readable summary on stdout (metrics are also logged above).
    print("\n" + "=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.2f}%)")
    print(
        f"Valid Format Rate: {metrics['valid_format_rate']:.4f} ({metrics['valid_format_rate'] * 100:.2f}%)"
    )
    print(
        f"Uses All Numbers Rate: {metrics['uses_all_numbers_rate']:.4f} ({metrics['uses_all_numbers_rate'] * 100:.2f}%)"
    )
    print(
        f"Correct Predictions: {metrics['correct_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Valid Format Predictions: {metrics['valid_format_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Uses All Numbers Predictions: {metrics['uses_all_numbers_predictions']}/{metrics['total_samples']}"
    )
    print("=" * 50)
|
|
|
|
|
|
|
|
|
# Run only when executed as a script, so the module stays importable
# without side effects.
if __name__ == "__main__":
    main()
|
|
|
|