""" Step 4: Analysis and visualization of Best-of-N vs greedy performance. This script creates plots comparing: 1. Overall accuracy: Greedy vs Majority Vote vs Standard BoN vs Weighted BoN 2. Accuracy vs N (how performance scales with number of samples) 3. Per-problem analysis: which problems did BoN solve that greedy couldn't? 4. PRM score distribution analysis Co-authored with Claude (Anthropic). I can explain all code logic. """ import json import matplotlib.pyplot as plt import matplotlib import numpy as np from collections import defaultdict matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150}) # ────────────────────────────────────────────────────────────────────────────── # Load results # ────────────────────────────────────────────────────────────────────────────── with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json") as f: bon_results = json.load(f) with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json") as f: accuracy_by_n = json.load(f) with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json") as f: scored_results = json.load(f) n_problems = len(bon_results) # ────────────────────────────────────────────────────────────────────────────── # Plot 1: Overall accuracy comparison (bar chart) # ────────────────────────────────────────────────────────────────────────────── fig, ax = plt.subplots(figsize=(8, 5)) methods = ["Greedy\n(N=1)", "Majority Vote\n(N=16)", "Standard BoN\n(N=16)", "Weighted BoN\n(N=16)"] accuracies = [ sum(r["greedy_correct"] for r in bon_results) / n_problems, sum(r["majority_vote_correct"] for r in bon_results) / n_problems, sum(r["standard_bon_correct"] for r in bon_results) / n_problems, sum(r["weighted_bon_correct"] for r in bon_results) / n_problems, ] colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"] bars = ax.bar(methods, accuracies, color=colors, edgecolor="white", linewidth=1.5) for bar, acc in zip(bars, accuracies): ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f"{acc:.0%}", ha="center", va="bottom", fontweight="bold", fontsize=12) ax.set_ylabel("Accuracy") ax.set_title("Math Problem Accuracy: Greedy vs Best-of-N Methods\n(20 MATH-500 problems, Levels 1-3)") ax.set_ylim(0, 1.05) ax.grid(axis="y", alpha=0.3) plt.tight_layout() plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot1_accuracy_comparison.png") plt.close() print("Saved plot1_accuracy_comparison.png") # ────────────────────────────────────────────────────────────────────────────── # Plot 2: Accuracy vs N # ────────────────────────────────────────────────────────────────────────────── fig, ax = plt.subplots(figsize=(7, 5)) ns = sorted([int(k) for k in accuracy_by_n.keys()]) accs = [accuracy_by_n[str(n)] for n in ns] ax.plot(ns, accs, "o-", color="#8172B2", linewidth=2, markersize=8, label="Weighted BoN") # Add greedy baseline as horizontal line greedy_acc = sum(r["greedy_correct"] for r in bon_results) / n_problems ax.axhline(y=greedy_acc, color="#4C72B0", linestyle="--", linewidth=1.5, label=f"Greedy baseline ({greedy_acc:.0%})") for n, acc in zip(ns, accs): ax.annotate(f"{acc:.0%}", (n, acc), textcoords="offset points", xytext=(0, 10), ha="center", fontsize=10) ax.set_xlabel("N (number of samples)") ax.set_ylabel("Accuracy") ax.set_title("Weighted Best-of-N Accuracy vs Number of Samples") ax.set_xticks(ns) ax.set_ylim(0, 1.05) ax.legend() ax.grid(alpha=0.3) plt.tight_layout() plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot2_accuracy_vs_n.png") plt.close() print("Saved plot2_accuracy_vs_n.png") # ────────────────────────────────────────────────────────────────────────────── # Plot 3: Per-problem comparison (Greedy vs Weighted BoN) # ────────────────────────────────────────────────────────────────────────────── fig, ax = plt.subplots(figsize=(12, 5)) # Categorize problems categories = { "Both correct": [], "Only BoN correct": [], "Only Greedy correct": [], "Both wrong": [], } for r in bon_results: g = r["greedy_correct"] b = r["weighted_bon_correct"] label = f"L{r['level']}: {r['unique_id'].split('/')[-1][:15]}" if g and b: categories["Both correct"].append(label) elif not g and b: categories["Only BoN correct"].append(label) elif g and not b: categories["Only Greedy correct"].append(label) else: categories["Both wrong"].append(label) # Color map for the stacked bars cat_colors = { "Both correct": "#55A868", "Only BoN correct": "#8172B2", "Only Greedy correct": "#C44E52", "Both wrong": "#CCCCCC", } # Create a categorical overview labels = [] colors_list = [] for r in bon_results: g = r["greedy_correct"] b = r["weighted_bon_correct"] label = f"L{r['level']}" labels.append(label) if g and b: colors_list.append(cat_colors["Both correct"]) elif not g and b: colors_list.append(cat_colors["Only BoN correct"]) elif g and not b: colors_list.append(cat_colors["Only Greedy correct"]) else: colors_list.append(cat_colors["Both wrong"]) x = range(len(bon_results)) # Plot n_correct_in_16 as bar height, colored by category heights = [r["n_correct_in_16"] for r in bon_results] ax.bar(x, heights, color=colors_list, edgecolor="white", linewidth=0.5) # Add problem labels ax.set_xticks(x) short_ids = [r["unique_id"].split("/")[-1].replace(".json", "")[:12] for r in bon_results] ax.set_xticklabels(short_ids, rotation=45, ha="right", fontsize=8) ax.set_ylabel("# Correct Solutions (out of 16)") ax.set_title("Per-Problem Analysis: Correct Solutions in N=16 Sample") # Legend from matplotlib.patches import Patch legend_elements = [Patch(facecolor=c, label=l) for l, c in cat_colors.items()] ax.legend(handles=legend_elements, loc="upper right", fontsize=9) ax.grid(axis="y", alpha=0.3) plt.tight_layout() plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot3_per_problem.png") plt.close() print("Saved plot3_per_problem.png") # ────────────────────────────────────────────────────────────────────────────── # Plot 4: PRM Score Distribution (correct vs incorrect solutions) # ────────────────────────────────────────────────────────────────────────────── fig, ax = plt.subplots(figsize=(7, 5)) correct_scores = [] incorrect_scores = [] for r in scored_results: for answer, score in zip(r["extracted_answers"], r["prm_scores"]): if answer == r["answer"]: correct_scores.append(score) else: incorrect_scores.append(score) bins = np.linspace(0, 1, 25) ax.hist(correct_scores, bins=bins, alpha=0.7, label=f"Correct ({len(correct_scores)})", color="#55A868") ax.hist(incorrect_scores, bins=bins, alpha=0.7, label=f"Incorrect ({len(incorrect_scores)})", color="#C44E52") ax.set_xlabel("PRM Last-Step Score") ax.set_ylabel("Count") ax.set_title("PRM Score Distribution: Correct vs Incorrect Solutions") ax.legend() ax.grid(alpha=0.3) plt.tight_layout() plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot4_prm_scores.png") plt.close() print("Saved plot4_prm_scores.png") # ────────────────────────────────────────────────────────────────────────────── # Print detailed analysis # ────────────────────────────────────────────────────────────────────────────── print("\n" + "=" * 70) print("DETAILED ANALYSIS") print("=" * 70) print(f"\nOverall Accuracies:") print(f" Greedy (N=1): {accuracies[0]:.0%}") print(f" Majority Vote (N=16): {accuracies[1]:.0%}") print(f" Standard Best-of-N (N=16): {accuracies[2]:.0%}") print(f" Weighted Best-of-N (N=16): {accuracies[3]:.0%}") print(f"\nProblems ONLY solved by Weighted BoN (not greedy):") for r in bon_results: if r["weighted_bon_correct"] and not r["greedy_correct"]: print(f" - {r['unique_id']} (Level {r['level']}, {r['subject']})") print(f" Ground truth: {r['ground_truth']}") print(f" Greedy answer: {r['greedy_answer']}") print(f" BoN answer: {r['weighted_bon_answer']}") print(f" Correct in sample: {r['n_correct_in_16']}/16") print(f"\nProblems ONLY solved by Greedy (not BoN):") for r in bon_results: if r["greedy_correct"] and not r["weighted_bon_correct"]: print(f" - {r['unique_id']} (Level {r['level']}, {r['subject']})") print(f" Ground truth: {r['ground_truth']}") print(f" Greedy answer: {r['greedy_answer']}") print(f" BoN answer: {r['weighted_bon_answer']}") print(f" Correct in sample: {r['n_correct_in_16']}/16") print(f"\nProblems neither method solved:") for r in bon_results: if not r["greedy_correct"] and not r["weighted_bon_correct"]: print(f" - {r['unique_id']} (Level {r['level']}, {r['subject']})") print(f" Ground truth: {r['ground_truth']}") print(f" Correct in sample: {r['n_correct_in_16']}/16") # PRM Score stats print(f"\nPRM Score Statistics:") print(f" Correct solutions: mean={np.mean(correct_scores):.3f}, median={np.median(correct_scores):.3f}") print(f" Incorrect solutions: mean={np.mean(incorrect_scores):.3f}, median={np.median(incorrect_scores):.3f}") # Accuracy by level print(f"\nAccuracy by problem level:") for level in sorted(set(r["level"] for r in bon_results)): level_results = [r for r in bon_results if r["level"] == level] n = len(level_results) g = sum(r["greedy_correct"] for r in level_results) w = sum(r["weighted_bon_correct"] for r in level_results) print(f" Level {level}: Greedy {g}/{n} ({g/n:.0%}) | Weighted BoN {w}/{n} ({w/n:.0%})") # Accuracy by subject print(f"\nAccuracy by subject:") subjects = sorted(set(r["subject"] for r in bon_results)) for subj in subjects: subj_results = [r for r in bon_results if r["subject"] == subj] n = len(subj_results) g = sum(r["greedy_correct"] for r in subj_results) w = sum(r["weighted_bon_correct"] for r in subj_results) print(f" {subj}: Greedy {g}/{n} | Weighted BoN {w}/{n}") print("\nAll plots saved to outputs/ directory.")