File size: 13,583 Bytes
00db46c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
#!/usr/bin/env python3
"""

Script to calculate accuracy of a trained GRPO model on arithmetic countdown problems.



This script loads a CSV file with problem data, performs inference using the trained model,

and calculates the accuracy by comparing predicted answers with correct answers.

"""
import sys, os
# Directory that contains this script file.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above this script; presumably the repository root
# (NOTE(review): confirm against the actual project layout).
project_root = os.path.abspath(os.path.join(current_dir, "../.."))
# Appended (not prepended), so already-importable modules keep precedence.
sys.path.append(project_root)

import argparse
import logging
import re
import sys
from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm

# Make the parent of this script's directory importable as well, so the
# `src.utils.inference` import below can resolve when the script is run
# directly (presumably that directory contains the `src` package -- confirm).
sys.path.append(str(Path(__file__).parent.parent))

from src.utils.inference import GRPOModelInference

# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-wide logger shared by every function in this script.
logger = logging.getLogger("calculate_accuracy")


def load_csv_data(csv_path: str) -> pd.DataFrame:
    """Read the problem set from a CSV file and check its schema.

    The file must provide the columns ``id``, ``problem_description``,
    ``correct_answer``, ``num1``, ``num2``, ``num3`` and ``num4``. Extra
    columns are kept untouched.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: If one or more of the required columns is absent.
    """
    expected = (
        "id",
        "problem_description",
        "correct_answer",
        "num1",
        "num2",
        "num3",
        "num4",
    )

    frame = pd.read_csv(csv_path)

    # Report every absent column at once, in the order listed above.
    absent = [name for name in expected if name not in frame.columns]
    if len(absent) > 0:
        raise ValueError(f"Missing required columns: {absent}")

    logger.info(f"Loaded {len(frame)} problems from {csv_path}")
    return frame


def safe_eval_expression(expression: str) -> tuple[float | None, bool]:
    """

    Safely evaluate an arithmetic expression.



    Args:

        expression: The arithmetic expression to evaluate



    Returns:

        Tuple of (result, is_valid)

    """
    if not expression or not expression.strip():
        return None, False

    # Replace 'x' with '*' for evaluation if present
    normalized = expression.replace("x", "*").replace("X", "*")

    # Basic validation - only allow numbers, operators, spaces, and parentheses
    allowed_chars = set("0123456789+-*/.() ")
    if not all(c in allowed_chars for c in normalized):
        return None, False

    try:
        result = eval(normalized)
        return result, True
    except (SyntaxError, ValueError, ZeroDivisionError, NameError):
        return None, False


def check_numbers_usage(expression: str, required_numbers: list[int]) -> bool:
    """Check that the expression uses exactly the required numbers.

    Every run of digits delimited by word boundaries is read as one integer,
    and the resulting multiset is compared with ``required_numbers``. ``x``
    and ``X`` are first replaced by ``*``, the same normalisation that
    ``safe_eval_expression`` applies, so that ``3x4`` counts as the numbers 3
    and 4 instead of being skipped (``x`` is a word character, so without the
    replacement there is no word boundary between a digit and ``x``).

    Args:
        expression: The arithmetic expression to check.
        required_numbers: Numbers that must each appear exactly once
            (duplicates in this list must appear the same number of times).

    Returns:
        True if the integers found in the expression match
        ``required_numbers`` exactly, ignoring order; False otherwise,
        including for an empty or whitespace-only expression.
    """
    if not expression or not expression.strip():
        return False

    # Keep number extraction consistent with how the expression is evaluated.
    normalized = expression.replace("x", "*").replace("X", "*")

    # Note: a decimal literal such as "2.5" yields two integers (2 and 5).
    found_numbers = [int(token) for token in re.findall(r"\b\d+\b", normalized)]

    # Order-insensitive multiset comparison.
    return sorted(required_numbers) == sorted(found_numbers)


def evaluate_prediction(
    predicted_answer: str, correct_answer: int, nums: list[int]
) -> dict:
    """Evaluate a single prediction against the correct answer.

    A prediction counts as correct only when it evaluates successfully, uses
    exactly the numbers in ``nums``, and its value is within 1e-6 of
    ``correct_answer``.

    Args:
        predicted_answer: The model's predicted arithmetic expression.
        correct_answer: The expected result of the expression.
        nums: The numbers the expression must use.

    Returns:
        Dictionary with the keys ``predicted_answer``, ``correct_answer``,
        ``is_correct``, ``is_valid_format``, ``uses_all_numbers``,
        ``predicted_result`` and ``correct_result``.
    """
    # Evaluate the predicted expression and check which numbers it uses.
    predicted_result, is_valid_predicted = safe_eval_expression(predicted_answer)
    uses_all_numbers = check_numbers_usage(predicted_answer, nums)

    # Log predicted and correct results
    logger.info(
        f"Answered: {predicted_answer} - Predicted result: {predicted_result} - Correct result: {correct_answer} - Uses all numbers: {uses_all_numbers}"
    )

    # The isinstance guard keeps a non-numeric evaluation result (for example
    # the empty tuple produced by "()") from raising TypeError in the
    # subtraction; such a prediction is simply scored as incorrect.
    is_correct = False
    if (
        is_valid_predicted
        and uses_all_numbers
        and isinstance(predicted_result, (int, float))
    ):
        is_correct = bool(abs(predicted_result - correct_answer) < 1e-6)

    # Key order is kept stable because it determines the column order of the
    # detailed results table built from these dictionaries.
    return {
        "predicted_answer": predicted_answer,
        "correct_answer": correct_answer,
        "is_correct": is_correct,
        "is_valid_format": is_valid_predicted,
        "uses_all_numbers": uses_all_numbers,
        "predicted_result": predicted_result,
        "correct_result": correct_answer,
    }


def calculate_accuracy(
    csv_path: str,
    sft_model_path: str | None,
    grpo_model_path: str | None,
    base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
    device: str = "auto",
    dtype: torch.dtype = torch.float16,
    max_new_tokens: int = 4096,
    temperature: float = 0.7,
    max_samples: int | None = None,
    output_path: str | None = None,
) -> dict:
    """Run the model over a CSV of problems and summarise how it performed.

    Each row's ``problem_description`` is sent to the model; the extracted
    answer is scored with ``evaluate_prediction`` against the row's
    ``correct_answer`` and its ``num1``..``num4`` values.

    Args:
        csv_path: Path to the CSV file with test data.
        sft_model_path: Path to the SFT model, or None.
        grpo_model_path: Path to the GRPO model, or None.
        base_model_id: Base model identifier.
        device: Device to run inference on.
        dtype: Data type for the model.
        max_new_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        max_samples: Evaluate only the first ``max_samples`` rows (None for all).
        output_path: If given, per-problem results are written there as CSV.

    Returns:
        Dictionary with ``total_samples``, the three raw counts
        (``correct_predictions``, ``valid_format_predictions``,
        ``uses_all_numbers_predictions``) and the matching rates
        (``accuracy``, ``valid_format_rate``, ``uses_all_numbers_rate``).
    """
    problems = load_csv_data(csv_path)

    if max_samples is not None:
        problems = problems.head(max_samples)
        logger.info(f"Limiting evaluation to {max_samples} samples")

    logger.info("Loading model...")
    model_inference = GRPOModelInference(
        sft_model_path=sft_model_path,
        grpo_model_path=grpo_model_path,
        base_model_id=base_model_id,
        device=device,
        dtype=dtype,
    )

    records = []
    # Running totals, keyed by the evaluation flag each one counts.
    tallies = {"is_correct": 0, "is_valid_format": 0, "uses_all_numbers": 0}

    logger.info("Starting evaluation...")
    progress = tqdm(problems.iterrows(), total=len(problems), desc="Evaluating")

    # `seen` is the 1-based count of problems processed so far.
    for seen, (_, row) in enumerate(progress, start=1):
        description = row["problem_description"]
        response, extracted_answer, _ = model_inference.solve_problem(
            problem_description=description,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        nums = [row["num1"], row["num2"], row["num3"], row["num4"]]
        record = evaluate_prediction(
            predicted_answer=extracted_answer,
            correct_answer=row["correct_answer"],
            nums=nums,
        )

        # Attach identifying metadata and the raw model output.
        record["id"] = row["id"]
        record["problem_description"] = description
        record["full_response"] = response
        record["nums"] = nums
        records.append(record)

        for flag in tallies:
            if record[flag]:
                tallies[flag] += 1

        # Show running rates on the progress bar.
        running_accuracy = tallies["is_correct"] / seen
        running_valid_rate = tallies["is_valid_format"] / seen
        progress.set_postfix(
            {
                "Acc": f"{running_accuracy:.3f}",
                "Valid": f"{running_valid_rate:.3f}",
                "Correct": f"{tallies['is_correct']}/{seen}",
            }
        )

    total_samples = len(records)
    correct_predictions = tallies["is_correct"]
    valid_format_predictions = tallies["is_valid_format"]
    uses_all_numbers_predictions = tallies["uses_all_numbers"]

    # With no samples at all, every rate is reported as 0.
    if total_samples > 0:
        accuracy = correct_predictions / total_samples
        valid_format_rate = valid_format_predictions / total_samples
        uses_all_numbers_rate = uses_all_numbers_predictions / total_samples
    else:
        accuracy = 0
        valid_format_rate = 0
        uses_all_numbers_rate = 0

    metrics = {
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
        "valid_format_predictions": valid_format_predictions,
        "uses_all_numbers_predictions": uses_all_numbers_predictions,
        "accuracy": accuracy,
        "valid_format_rate": valid_format_rate,
        "uses_all_numbers_rate": uses_all_numbers_rate,
    }

    logger.info("Evaluation completed!")
    logger.info(f"Total samples: {total_samples}")
    logger.info(f"Correct predictions: {correct_predictions}")
    logger.info(f"Valid format predictions: {valid_format_predictions}")
    logger.info(f"Uses all numbers predictions: {uses_all_numbers_predictions}")
    logger.info(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
    logger.info(
        f"Valid format rate: {valid_format_rate:.4f} ({valid_format_rate * 100:.2f}%)"
    )
    logger.info(
        f"Uses all numbers rate: {uses_all_numbers_rate:.4f} ({uses_all_numbers_rate * 100:.2f}%)"
    )

    # Persist the per-problem breakdown only when a destination was given.
    if output_path:
        pd.DataFrame(records).to_csv(output_path, index=False)
        logger.info(f"Detailed results saved to {output_path}")

    return metrics


def main():
    """Parse command-line options, run the evaluation and print a summary.

    ``--no_sft`` / ``--no_grpo`` replace the corresponding model path with
    None before it is passed to ``calculate_accuracy``. The model dtype is
    always ``torch.float16``; there is no command-line option for it.
    """
    parser = argparse.ArgumentParser(
        description="Calculate accuracy of GRPO model on arithmetic countdown problems"
    )

    # Not marked required: argparse ignores `default` on a required option,
    # so the declared default path could otherwise never take effect.
    parser.add_argument(
        "--csv_path",
        type=str,
        default="data/grpo/test.csv",
        help="Path to CSV file with test data",
    )
    parser.add_argument(
        "--sft_model_path",
        type=str,
        default="models/sft/",
        help="Path to SFT model directory",
    )
    parser.add_argument(
        "--grpo_model_path",
        type=str,
        default="models/grpo/",
        help="Path to GRPO model directory",
    )
    parser.add_argument(
        "--base_model_id",
        type=str,
        default="Qwen/Qwen2.5-Math-1.5B",
        help="Base model identifier",
    )
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=4096, help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to save detailed results CSV",
    )
    parser.add_argument(
        "--no_sft",
        action="store_true",
        help="Skip loading the SFT model (use only base model)",
    )
    parser.add_argument(
        "--no_grpo",
        action="store_true",
        help="Skip loading the GRPO model (use only SFT model)",
    )

    args = parser.parse_args()

    # Inference precision is fixed; it is not configurable from the CLI.
    dtype = torch.float16

    metrics = calculate_accuracy(
        csv_path=args.csv_path,
        sft_model_path=None if args.no_sft else args.sft_model_path,
        grpo_model_path=None if args.no_grpo else args.grpo_model_path,
        base_model_id=args.base_model_id,
        device=args.device,
        dtype=dtype,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        max_samples=args.max_samples,
        output_path=args.output_path,
    )

    # Human-readable summary on stdout (the same figures are also logged
    # by calculate_accuracy).
    print("\n" + "=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.2f}%)")
    print(
        f"Valid Format Rate: {metrics['valid_format_rate']:.4f} ({metrics['valid_format_rate'] * 100:.2f}%)"
    )
    print(
        f"Uses All Numbers Rate: {metrics['uses_all_numbers_rate']:.4f} ({metrics['uses_all_numbers_rate'] * 100:.2f}%)"
    )
    print(
        f"Correct Predictions: {metrics['correct_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Valid Format Predictions: {metrics['valid_format_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Uses All Numbers Predictions: {metrics['uses_all_numbers_predictions']}/{metrics['total_samples']}"
    )
    print("=" * 50)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()