#!/usr/bin/env python3
"""
Script to calculate accuracy of a trained GRPO model on arithmetic countdown problems.

This script loads a CSV file with problem data, performs inference using the trained
model, and calculates the accuracy by comparing predicted answers with correct answers.
"""

import argparse
import logging
import os
import re
import sys
from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm

# Both entries must be on sys.path before the project-local import below:
# project_root is two levels above this script's directory, and the second
# entry is one level above it.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../.."))
sys.path.append(project_root)
sys.path.append(str(Path(__file__).parent.parent))

from src.utils.inference import GRPOModelInference  # noqa: E402

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("calculate_accuracy")

# Columns that every input CSV must provide.
REQUIRED_COLUMNS = (
    "id",
    "problem_description",
    "correct_answer",
    "num1",
    "num2",
    "num3",
    "num4",
)

# Characters permitted in an expression after 'x'/'X' has been mapped to '*'.
_ALLOWED_EXPRESSION_CHARS = frozenset("0123456789+-*/.() ")
# Stand-alone integer tokens.
_INTEGER_RE = re.compile(r"\b\d+\b")
# A decimal point touching a digit, i.e. a float literal such as "1.5" or ".5".
_DECIMAL_RE = re.compile(r"\d\.|\.\d")


def _normalize_expression(expression: str) -> str:
    """Map the 'x' / 'X' multiplication sign to '*'."""
    return expression.replace("x", "*").replace("X", "*")


def _rate(count: int, total: int) -> float:
    """Return count / total, or 0.0 when total is zero."""
    return count / total if total > 0 else 0.0


def load_csv_data(csv_path: str) -> pd.DataFrame:
    """
    Load CSV data with the expected format.

    Expected columns: id, problem_description, correct_answer, num1, num2, num3, num4

    Args:
        csv_path: Path to the CSV file

    Returns:
        DataFrame with the loaded data

    Raises:
        ValueError: If any of the expected columns is missing
    """
    df = pd.read_csv(csv_path)

    # Verify required columns exist
    missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    logger.info(f"Loaded {len(df)} problems from {csv_path}")
    return df


def safe_eval_expression(expression: str) -> tuple[float | None, bool]:
    """
    Safely evaluate an arithmetic expression.

    'x' and 'X' are accepted as multiplication signs. Only digits, '+', '-',
    '*', '/', '.', parentheses and spaces are allowed. The power operator
    ('**') is rejected, and an expression that does not evaluate to a plain
    number is reported as invalid.

    Args:
        expression: The arithmetic expression to evaluate

    Returns:
        Tuple of (result, is_valid); result is None whenever is_valid is False
    """
    # Also covers None and other non-string inputs.
    if not isinstance(expression, str) or not expression.strip():
        return None, False

    normalized = _normalize_expression(expression)

    # Basic validation - only allow numbers, operators, spaces, and parentheses
    if not _ALLOWED_EXPRESSION_CHARS.issuperset(normalized):
        return None, False

    # '**' can be spelled with whitelisted characters and allows unbounded
    # computation (e.g. "9**9**9**9"), so refuse it up front.
    if "**" in normalized:
        return None, False

    try:
        # SECURITY NOTE: eval() runs on model-generated text. The character
        # whitelist above excludes names, attribute access and strings, and
        # builtins are removed as defence in depth. Consider replacing this
        # with an ast-based evaluator.
        result = eval(normalized, {"__builtins__": {}}, {})
    except (
        SyntaxError,
        ValueError,
        ZeroDivisionError,
        NameError,
        TypeError,  # e.g. "2(3)": calling an int
        OverflowError,
        MemoryError,
        RecursionError,
    ):
        return None, False

    # Whitelisted characters can still produce non-numbers: "()" is an empty
    # tuple and "..." is Ellipsis.
    if isinstance(result, bool) or not isinstance(result, (int, float)):
        return None, False

    return result, True


def check_numbers_usage(expression: str, required_numbers: list[int]) -> bool:
    """
    Check if the expression uses exactly the required numbers.

    The expression is normalized the same way as in safe_eval_expression
    ('x' / 'X' become '*'), so "3x4" counts as using 3 and 4. Expressions
    containing decimal literals are rejected, because "1.5" is not a use of
    the integers 1 and 5.

    Args:
        expression: The arithmetic expression to check
        required_numbers: List of numbers that should be used exactly once each

    Returns:
        True if expression uses all required numbers exactly once, False otherwise
    """
    # Also covers None and other non-string inputs.
    if not isinstance(expression, str) or not expression.strip():
        return False

    normalized = _normalize_expression(expression)

    if _DECIMAL_RE.search(normalized):
        return False

    # Extract all numbers from the expression and convert to integers
    try:
        found_numbers = [int(token) for token in _INTEGER_RE.findall(normalized)]
    except ValueError:
        # int() refuses extremely long digit strings on recent Python versions.
        return False

    # Compare as multisets: order is irrelevant, multiplicity is not.
    return sorted(required_numbers) == sorted(found_numbers)


def evaluate_prediction(
    predicted_answer: str, correct_answer: int, nums: list[int]
) -> dict:
    """
    Evaluate a single prediction against the correct answer.

    Args:
        predicted_answer: The model's predicted arithmetic expression
        correct_answer: The correct integer result
        nums: List of four numbers used in the problem

    Returns:
        Dictionary with evaluation results (keys: predicted_answer,
        correct_answer, is_correct, is_valid_format, uses_all_numbers,
        predicted_result, correct_result)
    """
    # Evaluate predicted answer
    predicted_result, is_valid_predicted = safe_eval_expression(predicted_answer)

    # Check if all required numbers are used
    uses_all_numbers = check_numbers_usage(predicted_answer, nums)

    # Log predicted and correct results
    logger.info(
        f"Answered: {predicted_answer} - Predicted result: {predicted_result} - Correct result: {correct_answer} - Uses all numbers: {uses_all_numbers}"
    )

    # Prediction is correct only if it is valid, uses all numbers, and has the
    # correct result. bool() keeps the flag a plain Python bool even when
    # correct_answer is a NumPy scalar.
    is_correct = False
    if is_valid_predicted and predicted_result is not None and uses_all_numbers:
        is_correct = bool(abs(predicted_result - correct_answer) < 1e-6)

    return {
        "predicted_answer": predicted_answer,
        "correct_answer": correct_answer,
        "is_correct": is_correct,
        "is_valid_format": is_valid_predicted,
        "uses_all_numbers": uses_all_numbers,
        "predicted_result": predicted_result,
        "correct_result": correct_answer,
    }


def calculate_accuracy(
    csv_path: str,
    sft_model_path: str | None,
    grpo_model_path: str | None,
    base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
    device: str = "auto",
    dtype: torch.dtype = torch.float16,
    max_new_tokens: int = 4096,
    temperature: float = 0.7,
    max_samples: int | None = None,
    output_path: str | None = None,
) -> dict:
    """
    Calculate accuracy of the model on the given dataset.

    Args:
        csv_path: Path to the CSV file with test data
        sft_model_path: Path to the SFT model
        grpo_model_path: Path to the GRPO model
        base_model_id: Base model identifier
        device: Device to run inference on
        dtype: Data type for the model
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        max_samples: Maximum number of samples to evaluate (None for all)
        output_path: Path to save detailed results (optional)

    Returns:
        Dictionary with accuracy metrics (total_samples, correct_predictions,
        valid_format_predictions, uses_all_numbers_predictions, accuracy,
        valid_format_rate, uses_all_numbers_rate)
    """
    # Load data
    df = load_csv_data(csv_path)

    if max_samples is not None:
        df = df.head(max_samples)
        logger.info(f"Limiting evaluation to {max_samples} samples")

    # Initialize model
    logger.info("Loading model...")
    model_inference = GRPOModelInference(
        sft_model_path=sft_model_path,
        grpo_model_path=grpo_model_path,
        base_model_id=base_model_id,
        device=device,
        dtype=dtype,
    )

    # Evaluate each problem
    results = []
    correct_predictions = 0
    valid_format_predictions = 0
    uses_all_numbers_predictions = 0

    logger.info("Starting evaluation...")

    pbar = tqdm(df.iterrows(), total=len(df), desc="Evaluating")
    # `seen` counts the problems processed so far, including the current one.
    for seen, (_, row) in enumerate(pbar, start=1):
        # Perform inference
        response, extracted_answer, _ = model_inference.solve_problem(
            problem_description=row["problem_description"],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Evaluate prediction
        nums = [row["num1"], row["num2"], row["num3"], row["num4"]]
        evaluation = evaluate_prediction(
            predicted_answer=extracted_answer,
            correct_answer=row["correct_answer"],
            nums=nums,
        )

        # Add metadata
        evaluation.update(
            {
                "id": row["id"],
                "problem_description": row["problem_description"],
                "full_response": response,
                "nums": nums,
            }
        )
        results.append(evaluation)

        # Update counters
        if evaluation["is_correct"]:
            correct_predictions += 1
        if evaluation["is_valid_format"]:
            valid_format_predictions += 1
        if evaluation["uses_all_numbers"]:
            uses_all_numbers_predictions += 1

        # Update progress bar with intermediate results
        current_accuracy = _rate(correct_predictions, seen)
        current_valid_rate = _rate(valid_format_predictions, seen)
        pbar.set_postfix(
            {
                "Acc": f"{current_accuracy:.3f}",
                "Valid": f"{current_valid_rate:.3f}",
                "Correct": f"{correct_predictions}/{seen}",
            }
        )

    # Calculate metrics
    total_samples = len(results)
    accuracy = _rate(correct_predictions, total_samples)
    valid_format_rate = _rate(valid_format_predictions, total_samples)
    uses_all_numbers_rate = _rate(uses_all_numbers_predictions, total_samples)

    metrics = {
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
        "valid_format_predictions": valid_format_predictions,
        "uses_all_numbers_predictions": uses_all_numbers_predictions,
        "accuracy": accuracy,
        "valid_format_rate": valid_format_rate,
        "uses_all_numbers_rate": uses_all_numbers_rate,
    }

    # Log results
    logger.info("Evaluation completed!")
    logger.info(f"Total samples: {total_samples}")
    logger.info(f"Correct predictions: {correct_predictions}")
    logger.info(f"Valid format predictions: {valid_format_predictions}")
    logger.info(f"Uses all numbers predictions: {uses_all_numbers_predictions}")
    logger.info(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
    logger.info(
        f"Valid format rate: {valid_format_rate:.4f} ({valid_format_rate * 100:.2f}%)"
    )
    logger.info(
        f"Uses all numbers rate: {uses_all_numbers_rate:.4f} ({uses_all_numbers_rate * 100:.2f}%)"
    )

    # Save detailed results if requested
    if output_path:
        # Create the target directory so a long run is not lost at save time.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_path, index=False)
        logger.info(f"Detailed results saved to {output_path}")

    return metrics


def main():
    """Main function to run the accuracy calculation script."""
    parser = argparse.ArgumentParser(
        description="Calculate accuracy of GRPO model on arithmetic countdown problems"
    )
    # Not marked required: a required option never falls back to its default.
    parser.add_argument(
        "--csv_path",
        type=str,
        default="data/grpo/test.csv",
        help="Path to CSV file with test data",
    )
    parser.add_argument(
        "--sft_model_path",
        type=str,
        default="models/sft/",
        help="Path to SFT model directory",
    )
    parser.add_argument(
        "--grpo_model_path",
        type=str,
        default="models/grpo/",
        help="Path to GRPO model directory",
    )
    parser.add_argument(
        "--base_model_id",
        type=str,
        default="Qwen/Qwen2.5-Math-1.5B",
        help="Base model identifier",
    )
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=4096, help="Maximum tokens to generate"
    )
    # NOTE(review): this CLI default (1.0) differs from calculate_accuracy's
    # own default (0.7) - confirm which one is intended.
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to save detailed results CSV",
    )
    parser.add_argument(
        "--no_sft",
        action="store_true",
        help="Skip loading the SFT model (use only base model)",
    )
    parser.add_argument(
        "--no_grpo",
        action="store_true",
        help="Skip loading the GRPO model (use only SFT model)",
    )

    args = parser.parse_args()

    # The dtype is fixed to float16; there is no CLI option for it.
    dtype = torch.float16

    # Calculate accuracy
    metrics = calculate_accuracy(
        csv_path=args.csv_path,
        sft_model_path=args.sft_model_path if not args.no_sft else None,
        grpo_model_path=args.grpo_model_path if not args.no_grpo else None,
        base_model_id=args.base_model_id,
        device=args.device,
        dtype=dtype,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        max_samples=args.max_samples,
        output_path=args.output_path,
    )

    print("\n" + "=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.2f}%)")
    print(
        f"Valid Format Rate: {metrics['valid_format_rate']:.4f} ({metrics['valid_format_rate'] * 100:.2f}%)"
    )
    print(
        f"Uses All Numbers Rate: {metrics['uses_all_numbers_rate']:.4f} ({metrics['uses_all_numbers_rate'] * 100:.2f}%)"
    )
    print(
        f"Correct Predictions: {metrics['correct_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Valid Format Predictions: {metrics['valid_format_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Uses All Numbers Predictions: {metrics['uses_all_numbers_predictions']}/{metrics['total_samples']}"
    )
    print("=" * 50)


if __name__ == "__main__":
    main()