| """ |
| Step 3: Compute Best-of-N accuracy with weighted selection. |
| |
Best-of-N weighted selection (DeepMind, arXiv:2408.03314, Section 5.1):
| 1. For each problem, we have N=16 solutions with PRM scores |
| 2. Extract the final answer from each solution |
| 3. Group solutions by their final answer string |
| 4. Sum the PRM scores within each group (weighted vote) |
| 5. Select the answer with the highest total weighted score |
| |
| This is formally: |
| â = argmax_a Σᵢ 𝟙(aᵢ = a) · score(sᵢ) |
| |
| Where score(sᵢ) is the PRM's last-step prediction for solution i. |
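
Toy example (hypothetical numbers): with answers ["4", "4", "5"] and
PRM scores [0.5, 0.25, 0.6], the weighted vote totals 0.75 for "4" and
0.6 for "5", so "4" is selected even though the single highest-scored
solution answered "5" (which is what standard Best-of-N would pick).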
| |
| Co-authored with Claude (Anthropic). I can explain all code logic. |
| """ |
|
|
import json
import random
from collections import defaultdict
|
|
|
|
def extract_boxed_solution(text):
    """Extract the content of the last \\boxed{} in text, or None if absent/unclosed."""
    try:
        start_index = text.rindex("\\boxed{")
    except ValueError:
        # No \boxed{} anywhere in the text.
        return None
    content_start = start_index + len("\\boxed{")
    # Walk forward tracking brace depth so nested braces (e.g. \frac{1}{2}) survive.
    bracket_count = 1
    current_pos = content_start
    while bracket_count > 0 and current_pos < len(text):
        if text[current_pos] == "{":
            bracket_count += 1
        elif text[current_pos] == "}":
            bracket_count -= 1
        current_pos += 1
    if bracket_count == 0:
        return text[content_start : current_pos - 1].strip()
    return None  # unbalanced braces: the box was never closed
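
# Toy usage: extract_boxed_solution("thus \\boxed{\\frac{1}{2}} qed")
# returns "\frac{1}{2}"; nested braces are handled by the depth counter.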
|
|
|
|
| def weighted_best_of_n(extracted_answers, prm_scores): |
| """ |
| Compute the Best-of-N answer using weighted selection. |
| |
| Groups solutions by their extracted answer, sums PRM scores |
| per group, and returns the answer with the highest total score. |
| |
| Args: |
| extracted_answers: list of N answer strings (may contain None) |
| prm_scores: list of N PRM scores (floats in [0,1]) |
| |
| Returns: |
| tuple: (best_answer, answer_scores_dict) |
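
    Toy example (hypothetical values):
        >>> weighted_best_of_n(["4", "4", "5"], [0.5, 0.25, 0.6])
        ('4', {'4': 0.75, '5': 0.6})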
| """ |
    answer_scores = defaultdict(float)

    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            # Skip solutions where no \boxed{} answer could be extracted.
            continue
        answer_scores[answer] += score

    if not answer_scores:
        # No solution produced a parseable answer.
        return None, {}

    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, dict(answer_scores)
|
|
|
|
| def standard_best_of_n(extracted_answers, prm_scores): |
| """ |
| Standard (non-weighted) Best-of-N: pick the single solution |
| with the highest PRM score and use its answer. |
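
    Toy example (hypothetical values):
        >>> standard_best_of_n(["4", "4", "5"], [0.5, 0.25, 0.6])
        '5'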
| """ |
    best_idx = None
    best_score = float("-inf")  # PRM scores lie in [0, 1], but -inf is safe regardless
    for i, (answer, score) in enumerate(zip(extracted_answers, prm_scores)):
        if answer is not None and score > best_score:
            best_score = score
            best_idx = i
    if best_idx is not None:
        return extracted_answers[best_idx]
    return None
|
|
|
|
| def majority_vote(extracted_answers): |
| """ |
| Pure majority vote (no reward weighting): pick the most frequent answer. |
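
    Toy example:
        >>> majority_vote(["4", "4", "5", None])
        '4'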
| """ |
| counts = defaultdict(int) |
| for answer in extracted_answers: |
| if answer is not None: |
| counts[answer] += 1 |
| if not counts: |
| return None |
| return max(counts, key=counts.get) |
|
|
|
|
| |
| |
| |
| print("=" * 70) |
| print("STEP 3: Computing Best-of-N accuracy with weighted selection") |
| print("=" * 70) |
|
|
| with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json") as f: |
| scored_results = json.load(f) |
|
|
| |
| with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json") as f: |
| greedy_results = json.load(f) |
|
|
| |
| |
| |
| weighted_correct = 0 |
| standard_correct = 0 |
| majority_correct = 0 |
| greedy_correct_count = 0 |
|
|
| results_summary = [] |
|
|
# NOTE: assumes both result files list problems in the same order.
for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)):
| problem_id = scored["unique_id"] |
| ground_truth = scored["answer"] |
|
|
| |
| extracted = scored["extracted_answers"] |
| scores = scored["prm_scores"] |
|
|
| |
    # Weighted Best-of-N: group solutions by answer and sum PRM scores per group.
    # Correctness throughout is exact string match against the ground truth.
    weighted_answer, answer_scores = weighted_best_of_n(extracted, scores)
    weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth)
| if weighted_is_correct: |
| weighted_correct += 1 |
|
|
| |
    # Standard Best-of-N: take the answer of the single highest-scored solution.
    standard_answer = standard_best_of_n(extracted, scores)
| standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth) |
| if standard_is_correct: |
| standard_correct += 1 |
|
|
| |
    # Majority vote over extracted answers (ignores PRM scores entirely).
    majority_answer = majority_vote(extracted)
| majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth) |
| if majority_is_correct: |
| majority_correct += 1 |
|
|
| |
    # Greedy baseline (N=1), precomputed and loaded from greedy_results.json.
    greedy_answer = greedy["greedy_extracted_answer"]
| greedy_is_correct = greedy["greedy_correct"] |
| if greedy_is_correct: |
| greedy_correct_count += 1 |
|
|
| |
    # Diagnostic: how many of the N samples answered correctly.
    n_correct_in_sample = sum(1 for a in extracted if a == ground_truth)
|
|
| |
| summary = { |
| "idx": i, |
| "unique_id": problem_id, |
| "level": scored["level"], |
| "subject": scored["subject"], |
| "ground_truth": ground_truth, |
| "greedy_answer": greedy_answer, |
| "greedy_correct": greedy_is_correct, |
| "weighted_bon_answer": weighted_answer, |
| "weighted_bon_correct": weighted_is_correct, |
| "standard_bon_answer": standard_answer, |
| "standard_bon_correct": standard_is_correct, |
| "majority_vote_answer": majority_answer, |
| "majority_vote_correct": majority_is_correct, |
| "n_correct_in_16": n_correct_in_sample, |
| "answer_score_breakdown": answer_scores, |
| "prm_scores": scores, |
| } |
| results_summary.append(summary) |
|
|
| |
| status_g = "✓" if greedy_is_correct else "✗" |
| status_w = "✓" if weighted_is_correct else "✗" |
| print(f"\n [{problem_id}] Level {scored['level']} | {scored['subject']}") |
| print(f" Ground truth: {ground_truth}") |
| print(f" Greedy {status_g}: {greedy_answer}") |
| print(f" Weighted BoN {status_w}: {weighted_answer}") |
| print(f" Correct in sample: {n_correct_in_sample}/{len(extracted)}") |
| if answer_scores: |
| print(f" Score breakdown: {dict(sorted(answer_scores.items(), key=lambda x: -x[1]))}") |
|
|
| |
| |
| |
| n_problems = len(scored_results) |
| print("\n" + "=" * 70) |
| print("RESULTS SUMMARY") |
| print("=" * 70) |
| print(f" Greedy (N=1): {greedy_correct_count}/{n_problems} = {greedy_correct_count/n_problems:.1%}") |
| print(f" Majority Vote (N=16): {majority_correct}/{n_problems} = {majority_correct/n_problems:.1%}") |
| print(f" Standard Best-of-N (N=16): {standard_correct}/{n_problems} = {standard_correct/n_problems:.1%}") |
| print(f" Weighted Best-of-N (N=16): {weighted_correct}/{n_problems} = {weighted_correct/n_problems:.1%}") |
|
|
| |
| with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json", "w") as f: |
| json.dump(results_summary, f, indent=2) |
| print("\nSaved detailed results to outputs/bon_results.json") |
|
|
| |
| |
| |
| print("\n" + "=" * 70) |
| print("ANALYSIS: How accuracy varies with N") |
| print("=" * 70) |
|
|
random.seed(42)  # fixed seed for reproducible subsampling
|
|
| n_values = [1, 2, 4, 8, 16] |
| n_trials = 50 |
|
|
| accuracy_by_n = {} |
| for n in n_values: |
    if n == 16:
        # All 16 samples form a single deterministic subset; no resampling needed.
        correct = 0
| for s in scored_results: |
| answer, _ = weighted_best_of_n(s["extracted_answers"], s["prm_scores"]) |
| if answer == s["answer"]: |
| correct += 1 |
| acc = correct / n_problems |
    else:
        # For n < 16, estimate accuracy by averaging over random subsamples.
        trial_accs = []
| for trial in range(n_trials): |
| correct = 0 |
| for s in scored_results: |
| |
                indices = random.sample(range(16), n)  # n of the 16 solutions, without replacement
| sub_answers = [s["extracted_answers"][j] for j in indices] |
| sub_scores = [s["prm_scores"][j] for j in indices] |
| answer, _ = weighted_best_of_n(sub_answers, sub_scores) |
| if answer == s["answer"]: |
| correct += 1 |
| trial_accs.append(correct / n_problems) |
| acc = sum(trial_accs) / len(trial_accs) |
|
|
| accuracy_by_n[n] = acc |
| print(f" N={n:2d}: {acc:.1%}") |
|
|
| |
| with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json", "w") as f: |
| json.dump(accuracy_by_n, f, indent=2) |
|
|
| print("\nDone! Results saved. Run step4_analysis.py for plots.") |
|
|