# Source: math500-bon-exercise / code/step3_best_of_n.py
# Uploaded by cmpatino (HF Staff): "Upload code/step3_best_of_n.py with huggingface_hub"
# Commit 312d5c4 (verified)
"""
Step 3: Compute Best-of-N accuracy with weighted selection.
Best-of-N weighted selection (from DeepMind 2408.03314, Section 5.1):
1. For each problem, we have N=16 solutions with PRM scores
2. Extract the final answer from each solution
3. Group solutions by their final answer string
4. Sum the PRM scores within each group (weighted vote)
5. Select the answer with the highest total weighted score
This is formally:
â = argmax_a Σᵢ 𝟙(aᵢ = a) · score(sᵢ)
Where score(sᵢ) is the PRM's last-step prediction for solution i.
Co-authored with Claude (Anthropic). I can explain all code logic.
"""
import json
from collections import defaultdict
def extract_boxed_solution(text):
    """Extract content of the last \\boxed{} in text.

    Braces nested inside the box (e.g. ``\\boxed{\\frac{1}{2}}``) are
    balanced, so the whole expression is returned.

    Args:
        text: String to search. Any non-string value (e.g. ``None``) is
            treated as containing no boxed answer.

    Returns:
        The content of the last box with surrounding whitespace stripped,
        or ``None`` if there is no ``\\boxed{`` marker or its opening brace
        is never closed.
    """
    if not isinstance(text, str):
        return None

    marker = "\\boxed{"
    try:
        # rindex: the *last* box is taken as the final answer.
        start_index = text.rindex(marker)
    except ValueError:
        return None

    content_start = start_index + len(marker)
    depth = 1  # the marker's own "{" is already open
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # pos is the matching closing brace; exclude it.
                return text[content_start:pos].strip()

    # Ran off the end of the text with the box still open.
    return None
def weighted_best_of_n(extracted_answers, prm_scores):
    """
    Compute the Best-of-N answer using weighted selection.

    Groups solutions by their extracted answer, sums PRM scores
    per group, and returns the answer with the highest total score.
    The two lists are paired positionally with ``zip``, so any surplus
    entries in the longer list are ignored.

    Args:
        extracted_answers: list of N answer strings (may contain None)
        prm_scores: list of N PRM scores (floats in [0,1])

    Returns:
        tuple: (best_answer, answer_scores_dict). ``answer_scores_dict`` maps
        each distinct non-None answer to its summed score. When every
        answer is None (or the input is empty) the result is ``(None, {})``.
        If several answers tie for the highest total, the one that appeared
        first in ``extracted_answers`` wins.
    """
    answer_scores = defaultdict(float)
    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            # Skip solutions where we couldn't extract an answer
            # (following DeepMind's filtering of unparseable solutions)
            continue
        answer_scores[answer] += score

    if not answer_scores:
        return None, {}

    # Select the answer with highest total weighted score; max() keeps the
    # first maximal key, and dict order is first-seen order.
    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, dict(answer_scores)
def standard_best_of_n(extracted_answers, prm_scores):
    """
    Standard (non-weighted) Best-of-N: pick the single solution
    with the highest PRM score and use its answer.

    Solutions whose answer is None are never candidates. Scores may be any
    comparable numbers (including values at or below -1); on a tie the
    earliest solution wins.

    Args:
        extracted_answers: list of N answer strings (may contain None)
        prm_scores: list of N PRM scores, paired positionally with the answers

    Returns:
        The answer of the highest-scoring parseable solution, or ``None``
        if no solution has a parseable answer.
    """
    best_answer = None
    best_score = None  # None = "no candidate seen yet"; avoids a numeric sentinel
    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            continue
        # Strict ">" keeps the first solution among equal scores.
        if best_score is None or score > best_score:
            best_score = score
            best_answer = answer
    return best_answer
def majority_vote(extracted_answers):
    """
    Plain majority vote, ignoring reward scores entirely.

    Counts how often each non-None answer occurs and returns the most
    frequent one. Among equally frequent answers, the one seen first wins.
    Returns None when there is no non-None answer to vote on.
    """
    tally = {}
    for candidate in extracted_answers:
        if candidate is None:
            continue
        tally[candidate] = tally.get(candidate, 0) + 1

    return max(tally, key=tally.get) if tally else None
# ──────────────────────────────────────────────────────────────────────────────
# Load scored results
# ──────────────────────────────────────────────────────────────────────────────
print("=" * 70)
print("STEP 3: Computing Best-of-N accuracy with weighted selection")
print("=" * 70)

# Decode explicitly as UTF-8: without `encoding=` open() falls back to the
# platform locale, which would misread any non-ASCII text in the JSON on
# machines whose default encoding is not UTF-8.
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json", encoding="utf-8") as f:
    scored_results = json.load(f)

# Also load greedy results for comparison
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json", encoding="utf-8") as f:
    greedy_results = json.load(f)
# ──────────────────────────────────────────────────────────────────────────────
# Compute Best-of-N for each problem
# ──────────────────────────────────────────────────────────────────────────────
# The two result lists are paired positionally below. zip() would silently
# drop scored problems that have no greedy counterpart, while the accuracies
# are reported over len(scored_results) — so fail loudly instead of
# under-counting every method.
if len(greedy_results) < len(scored_results):
    raise ValueError(
        f"greedy_results has {len(greedy_results)} entries but scored_results "
        f"has {len(scored_results)}; every scored problem needs a greedy baseline"
    )

# Running tallies of correctly answered problems, one per selection method.
weighted_correct = 0
standard_correct = 0
majority_correct = 0
greedy_correct_count = 0
results_summary = []

for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)):
    problem_id = scored["unique_id"]
    ground_truth = scored["answer"]

    # Extract answers from sampled solutions
    extracted = scored["extracted_answers"]
    scores = scored["prm_scores"]

    # Weighted Best-of-N
    weighted_answer, answer_scores = weighted_best_of_n(extracted, scores)
    weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth)
    if weighted_is_correct:
        weighted_correct += 1

    # Standard Best-of-N (for comparison)
    standard_answer = standard_best_of_n(extracted, scores)
    standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth)
    if standard_is_correct:
        standard_correct += 1

    # Majority vote (for comparison)
    majority_answer = majority_vote(extracted)
    majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth)
    if majority_is_correct:
        majority_correct += 1

    # Greedy baseline: correctness is read from the greedy results file,
    # not recomputed here.
    greedy_answer = greedy["greedy_extracted_answer"]
    greedy_is_correct = greedy["greedy_correct"]
    if greedy_is_correct:
        greedy_correct_count += 1

    # Count how many of the N solutions got the right answer
    # (exact match against the ground-truth value).
    n_correct_in_sample = sum(1 for a in extracted if a == ground_truth)

    # Summary for this problem
    summary = {
        "idx": i,
        "unique_id": problem_id,
        "level": scored["level"],
        "subject": scored["subject"],
        "ground_truth": ground_truth,
        "greedy_answer": greedy_answer,
        "greedy_correct": greedy_is_correct,
        "weighted_bon_answer": weighted_answer,
        "weighted_bon_correct": weighted_is_correct,
        "standard_bon_answer": standard_answer,
        "standard_bon_correct": standard_is_correct,
        "majority_vote_answer": majority_answer,
        "majority_vote_correct": majority_is_correct,
        "n_correct_in_16": n_correct_in_sample,
        "answer_score_breakdown": answer_scores,
        "prm_scores": scores,
    }
    results_summary.append(summary)

    # Print per-problem results
    status_g = "✓" if greedy_is_correct else "✗"
    status_w = "✓" if weighted_is_correct else "✗"
    print(f"\n [{problem_id}] Level {scored['level']} | {scored['subject']}")
    print(f" Ground truth: {ground_truth}")
    print(f" Greedy {status_g}: {greedy_answer}")
    print(f" Weighted BoN {status_w}: {weighted_answer}")
    print(f" Correct in sample: {n_correct_in_sample}/{len(extracted)}")
    if answer_scores:
        # Show candidate answers from highest to lowest summed score.
        print(f" Score breakdown: {dict(sorted(answer_scores.items(), key=lambda x: -x[1]))}")
# ──────────────────────────────────────────────────────────────────────────────
# Overall results
# ──────────────────────────────────────────────────────────────────────────────
# Every accuracy below uses the number of scored problems as its denominator.
n_problems = len(scored_results)

rule = "=" * 70
print("\n" + rule)
print("RESULTS SUMMARY")
print(rule)

# One line per selection method: correct count, total, and percentage.
print(f" Greedy (N=1): {greedy_correct_count}/{n_problems} = {greedy_correct_count/n_problems:.1%}")
print(f" Majority Vote (N=16): {majority_correct}/{n_problems} = {majority_correct/n_problems:.1%}")
print(f" Standard Best-of-N (N=16): {standard_correct}/{n_problems} = {standard_correct/n_problems:.1%}")
print(f" Weighted Best-of-N (N=16): {weighted_correct}/{n_problems} = {weighted_correct/n_problems:.1%}")

# Persist the per-problem summaries collected above.
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json", "w") as f:
    json.dump(results_summary, f, indent=2)
print("\nSaved detailed results to outputs/bon_results.json")
# ──────────────────────────────────────────────────────────────────────────────
# Compute Best-of-N at various N values (using the N=16 sample)
# ──────────────────────────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("ANALYSIS: How accuracy varies with N")
print("=" * 70)

import random

random.seed(42)  # fixed seed so the subsampled accuracies are reproducible
n_values = [1, 2, 4, 8, 16]
n_trials = 50  # Average over multiple random subsets for N < 16
accuracy_by_n = {}

for n in n_values:
    if n == 16:
        # Use all solutions: nothing to subsample, so a single pass suffices.
        correct = 0
        for s in scored_results:
            answer, _ = weighted_best_of_n(s["extracted_answers"], s["prm_scores"])
            # Same correctness rule as the main loop: an unparseable (None)
            # selection never counts as correct.
            if answer is not None and answer == s["answer"]:
                correct += 1
        acc = correct / n_problems
    else:
        # Subsample and average over trials
        trial_accs = []
        for _ in range(n_trials):
            correct = 0
            for s in scored_results:
                pool_answers = s["extracted_answers"]
                pool_scores = s["prm_scores"]
                # Size the pool from the data instead of assuming 16 samples,
                # so a problem with fewer solutions cannot raise IndexError.
                # With 16-sample data this draws exactly as range(16) would.
                pool_size = min(len(pool_answers), len(pool_scores))
                # Random subset of N solutions (capped at what is available)
                indices = random.sample(range(pool_size), min(n, pool_size))
                sub_answers = [pool_answers[j] for j in indices]
                sub_scores = [pool_scores[j] for j in indices]
                answer, _ = weighted_best_of_n(sub_answers, sub_scores)
                if answer is not None and answer == s["answer"]:
                    correct += 1
            trial_accs.append(correct / n_problems)
        acc = sum(trial_accs) / len(trial_accs)
    accuracy_by_n[n] = acc
    print(f" N={n:2d}: {acc:.1%}")

# Save accuracy-by-N for plotting (JSON turns the integer N keys into strings)
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json", "w") as f:
    json.dump(accuracy_by_n, f, indent=2)
print("\nDone! Results saved. Run step4_analysis.py for plots.")