""" Step 3: Compute Best-of-N accuracy with weighted selection. Best-of-N weighted selection (from DeepMind 2408.03314, Section 5.1): 1. For each problem, we have N=16 solutions with PRM scores 2. Extract the final answer from each solution 3. Group solutions by their final answer string 4. Sum the PRM scores within each group (weighted vote) 5. Select the answer with the highest total weighted score This is formally: รข = argmax_a ฮฃแตข ๐Ÿ™(aแตข = a) ยท score(sแตข) Where score(sแตข) is the PRM's last-step prediction for solution i. Co-authored with Claude (Anthropic). I can explain all code logic. """ import json from collections import defaultdict def extract_boxed_solution(text): """Extract content of the last \\boxed{} in text.""" try: start_index = text.rindex("\\boxed{") content_start = start_index + 7 bracket_count = 1 current_pos = content_start while bracket_count > 0 and current_pos < len(text): if text[current_pos] == "{": bracket_count += 1 elif text[current_pos] == "}": bracket_count -= 1 current_pos += 1 if bracket_count == 0: return text[content_start : current_pos - 1].strip() return None except (ValueError, Exception): return None def weighted_best_of_n(extracted_answers, prm_scores): """ Compute the Best-of-N answer using weighted selection. Groups solutions by their extracted answer, sums PRM scores per group, and returns the answer with the highest total score. Args: extracted_answers: list of N answer strings (may contain None) prm_scores: list of N PRM scores (floats in [0,1]) Returns: tuple: (best_answer, answer_scores_dict) """ answer_scores = defaultdict(float) answer_counts = defaultdict(int) for answer, score in zip(extracted_answers, prm_scores): if answer is None: # Skip solutions where we couldn't extract an answer # (following DeepMind's filtering of unparseable solutions) continue answer_scores[answer] += score answer_counts[answer] += 1 if not answer_scores: return None, {} # Select the answer with highest total weighted score best_answer = max(answer_scores, key=answer_scores.get) return best_answer, dict(answer_scores) def standard_best_of_n(extracted_answers, prm_scores): """ Standard (non-weighted) Best-of-N: pick the single solution with the highest PRM score and use its answer. """ best_idx = None best_score = -1 for i, (answer, score) in enumerate(zip(extracted_answers, prm_scores)): if answer is not None and score > best_score: best_score = score best_idx = i if best_idx is not None: return extracted_answers[best_idx] return None def majority_vote(extracted_answers): """ Pure majority vote (no reward weighting): pick the most frequent answer. 
""" counts = defaultdict(int) for answer in extracted_answers: if answer is not None: counts[answer] += 1 if not counts: return None return max(counts, key=counts.get) # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ # Load scored results # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ print("=" * 70) print("STEP 3: Computing Best-of-N accuracy with weighted selection") print("=" * 70) with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json") as f: scored_results = json.load(f) # Also load greedy results for comparison with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json") as f: greedy_results = json.load(f) # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ # Compute Best-of-N for each problem # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ weighted_correct = 0 standard_correct = 0 majority_correct = 0 greedy_correct_count = 0 results_summary = [] for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)): problem_id = scored["unique_id"] ground_truth = scored["answer"] # Extract answers from sampled solutions extracted = scored["extracted_answers"] scores = scored["prm_scores"] # Weighted Best-of-N weighted_answer, answer_scores = weighted_best_of_n(extracted, scores) weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth) if weighted_is_correct: weighted_correct += 1 # Standard Best-of-N (for comparison) standard_answer = standard_best_of_n(extracted, scores) standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth) if standard_is_correct: standard_correct += 1 # Majority vote (for comparison) majority_answer = majority_vote(extracted) majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth) if majority_is_correct: majority_correct += 1 # Greedy baseline greedy_answer = greedy["greedy_extracted_answer"] greedy_is_correct = greedy["greedy_correct"] if greedy_is_correct: greedy_correct_count += 1 # Count how many of the N solutions got the right answer n_correct_in_sample = sum(1 for a in extracted if a == ground_truth) # Summary for this problem summary = { "idx": i, "unique_id": problem_id, "level": scored["level"], "subject": scored["subject"], "ground_truth": ground_truth, "greedy_answer": greedy_answer, "greedy_correct": greedy_is_correct, "weighted_bon_answer": weighted_answer, "weighted_bon_correct": weighted_is_correct, "standard_bon_answer": standard_answer, "standard_bon_correct": standard_is_correct, "majority_vote_answer": majority_answer, "majority_vote_correct": majority_is_correct, "n_correct_in_16": n_correct_in_sample, "answer_score_breakdown": answer_scores, "prm_scores": scores, } results_summary.append(summary) # Print per-problem results status_g = "โœ“" if 
# ──────────────────────────────────────────────────────────────────────
# Load scored results
# ──────────────────────────────────────────────────────────────────────
print("=" * 70)
print("STEP 3: Computing Best-of-N accuracy with weighted selection")
print("=" * 70)

with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json") as f:
    scored_results = json.load(f)

# Also load greedy results for comparison
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json") as f:
    greedy_results = json.load(f)

# ──────────────────────────────────────────────────────────────────────
# Compute Best-of-N for each problem
# ──────────────────────────────────────────────────────────────────────
weighted_correct = 0
standard_correct = 0
majority_correct = 0
greedy_correct_count = 0

results_summary = []

for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)):
    problem_id = scored["unique_id"]
    ground_truth = scored["answer"]

    # Extract answers from sampled solutions
    extracted = scored["extracted_answers"]
    scores = scored["prm_scores"]

    # Weighted Best-of-N
    weighted_answer, answer_scores = weighted_best_of_n(extracted, scores)
    weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth)
    if weighted_is_correct:
        weighted_correct += 1

    # Standard Best-of-N (for comparison)
    standard_answer = standard_best_of_n(extracted, scores)
    standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth)
    if standard_is_correct:
        standard_correct += 1

    # Majority vote (for comparison)
    majority_answer = majority_vote(extracted)
    majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth)
    if majority_is_correct:
        majority_correct += 1

    # Greedy baseline
    greedy_answer = greedy["greedy_extracted_answer"]
    greedy_is_correct = greedy["greedy_correct"]
    if greedy_is_correct:
        greedy_correct_count += 1

    # Count how many of the N solutions got the right answer
    n_correct_in_sample = sum(1 for a in extracted if a == ground_truth)

    # Summary for this problem
    summary = {
        "idx": i,
        "unique_id": problem_id,
        "level": scored["level"],
        "subject": scored["subject"],
        "ground_truth": ground_truth,
        "greedy_answer": greedy_answer,
        "greedy_correct": greedy_is_correct,
        "weighted_bon_answer": weighted_answer,
        "weighted_bon_correct": weighted_is_correct,
        "standard_bon_answer": standard_answer,
        "standard_bon_correct": standard_is_correct,
        "majority_vote_answer": majority_answer,
        "majority_vote_correct": majority_is_correct,
        "n_correct_in_16": n_correct_in_sample,
        "answer_score_breakdown": answer_scores,
        "prm_scores": scores,
    }
    results_summary.append(summary)

    # Print per-problem results
    status_g = "✓" if greedy_is_correct else "✗"
    status_w = "✓" if weighted_is_correct else "✗"
    print(f"\n  [{problem_id}] Level {scored['level']} | {scored['subject']}")
    print(f"    Ground truth: {ground_truth}")
    print(f"    Greedy       {status_g}: {greedy_answer}")
    print(f"    Weighted BoN {status_w}: {weighted_answer}")
    print(f"    Correct in sample: {n_correct_in_sample}/{len(extracted)}")
    if answer_scores:
        print(f"    Score breakdown: {dict(sorted(answer_scores.items(), key=lambda x: -x[1]))}")

# ──────────────────────────────────────────────────────────────────────
# Overall results
# ──────────────────────────────────────────────────────────────────────
n_problems = len(scored_results)

print("\n" + "=" * 70)
print("RESULTS SUMMARY")
print("=" * 70)
print(f"  Greedy (N=1):              {greedy_correct_count}/{n_problems} = {greedy_correct_count/n_problems:.1%}")
print(f"  Majority Vote (N=16):      {majority_correct}/{n_problems} = {majority_correct/n_problems:.1%}")
print(f"  Standard Best-of-N (N=16): {standard_correct}/{n_problems} = {standard_correct/n_problems:.1%}")
print(f"  Weighted Best-of-N (N=16): {weighted_correct}/{n_problems} = {weighted_correct/n_problems:.1%}")

# Save results
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json", "w") as f:
    json.dump(results_summary, f, indent=2)
print("\nSaved detailed results to outputs/bon_results.json")

# ──────────────────────────────────────────────────────────────────────
# Compute Best-of-N at various N values (using the N=16 sample)
# ──────────────────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("ANALYSIS: How accuracy varies with N")
print("=" * 70)

random.seed(42)

n_values = [1, 2, 4, 8, 16]
n_trials = 50  # Average over multiple random subsets for N < 16

accuracy_by_n = {}
for n in n_values:
    if n == 16:
        # Use all solutions
        correct = 0
        for s in scored_results:
            answer, _ = weighted_best_of_n(s["extracted_answers"], s["prm_scores"])
            if answer == s["answer"]:
                correct += 1
        acc = correct / n_problems
    else:
        # Subsample and average over trials
        trial_accs = []
        for _ in range(n_trials):
            correct = 0
            for s in scored_results:
                # Random subset of N solutions
                indices = random.sample(range(len(s["extracted_answers"])), n)
                sub_answers = [s["extracted_answers"][j] for j in indices]
                sub_scores = [s["prm_scores"][j] for j in indices]
                answer, _ = weighted_best_of_n(sub_answers, sub_scores)
                if answer == s["answer"]:
                    correct += 1
            trial_accs.append(correct / n_problems)
        acc = sum(trial_accs) / len(trial_accs)
    accuracy_by_n[n] = acc
    print(f"  N={n:2d}: {acc:.1%}")

# Save accuracy-by-N for plotting
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json", "w") as f:
    json.dump(accuracy_by_n, f, indent=2)

print("\nDone! Results saved. Run step4_analysis.py for plots.")
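# ──────────────────────────────────────────────────────────────────────
# Optional extra: per-level accuracy breakdown. A minimal sketch using
# only fields already stored in results_summary; step4_analysis.py may
# slice these results in more depth.
# ──────────────────────────────────────────────────────────────────────
by_level = defaultdict(lambda: {"n": 0, "greedy": 0, "weighted": 0})
for r in results_summary:
    bucket = by_level[r["level"]]
    bucket["n"] += 1
    bucket["greedy"] += int(r["greedy_correct"])
    bucket["weighted"] += int(r["weighted_bon_correct"])

print("\nAccuracy by level (greedy vs. weighted BoN):")
for level in sorted(by_level):
    b = by_level[level]
    print(f"  Level {level}: greedy {b['greedy']}/{b['n']}, weighted {b['weighted']}/{b['n']}")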