File size: 10,740 Bytes
312d5c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""
Step 3: Compute Best-of-N accuracy with weighted selection.

Best-of-N weighted selection (from DeepMind 2408.03314, Section 5.1):
1. For each problem, we have N=16 solutions with PRM scores
2. Extract the final answer from each solution
3. Group solutions by their final answer string
4. Sum the PRM scores within each group (weighted vote)
5. Select the answer with the highest total weighted score

This is formally:
    â = argmax_a  Σᵢ 𝟙(aᵢ = a) · score(sᵢ)

Where score(sᵢ) is the PRM's last-step prediction for solution i.

Co-authored with Claude (Anthropic). I can explain all code logic.
"""

import json
from collections import defaultdict


def extract_boxed_solution(text):
    """Extract content of the last \\boxed{} in text.

    Finds the final occurrence of ``\\boxed{`` and scans forward while
    tracking brace depth, so nested braces such as ``\\frac{1}{2}`` are
    returned intact.

    Args:
        text: The solution text to search. Non-string values (e.g. ``None``)
            are treated as containing no boxed answer.

    Returns:
        The whitespace-stripped content between the braces, or ``None`` when
        ``text`` is not a string, contains no ``\\boxed{``, or the brace
        opened by the last ``\\boxed{`` is never closed.
    """
    marker = "\\boxed{"
    # Explicit guard instead of a catch-all ``except Exception``: bad input
    # still yields None, but genuine bugs are no longer silently swallowed.
    if not isinstance(text, str):
        return None

    start_index = text.rfind(marker)
    if start_index == -1:
        return None

    content_start = start_index + len(marker)
    depth = 1  # the brace opened by the marker itself
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # ``pos`` is the matching close brace; exclude it.
                return text[content_start:pos].strip()

    # Ran off the end of the text with braces still open.
    return None


def weighted_best_of_n(extracted_answers, prm_scores):
    """
    Compute the Best-of-N answer using weighted selection.

    Groups solutions by their extracted answer, sums PRM scores
    per group, and returns the answer with the highest total score.

    Answers and scores are paired positionally with ``zip``, so if the two
    lists differ in length the unmatched tail of the longer one is ignored.
    When several answers share the highest total, the one that appeared
    first in ``extracted_answers`` is returned.

    Args:
        extracted_answers: list of N answer strings (may contain None)
        prm_scores: list of N PRM scores (floats in [0,1])

    Returns:
        tuple: (best_answer, answer_scores_dict). ``answer_scores_dict`` is a
        plain dict mapping each non-None answer to its summed score. Returns
        ``(None, {})`` when no answer could be extracted from any solution.
    """
    answer_scores = defaultdict(float)

    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            # Skip solutions where we couldn't extract an answer
            # (following DeepMind's filtering of unparseable solutions)
            continue
        answer_scores[answer] += score

    if not answer_scores:
        return None, {}

    # Select the answer with highest total weighted score. max() keeps the
    # first maximal key, and dicts iterate in insertion order.
    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, dict(answer_scores)


def standard_best_of_n(extracted_answers, prm_scores):
    """
    Standard (non-weighted) Best-of-N: pick the single solution
    with the highest PRM score and use its answer.

    Solutions whose extracted answer is None are ignored. Ties are broken
    in favour of the earliest solution. The running maximum starts at
    negative infinity, so scores at or below -1 (e.g. log-probabilities)
    are still eligible; a score of ``-inf`` or ``NaN`` is never selected
    because it does not compare greater than the starting value.

    Args:
        extracted_answers: list of N answer strings (may contain None)
        prm_scores: list of N scores, paired positionally with the answers

    Returns:
        The answer of the highest-scoring solution, or None when no
        solution has both an extracted answer and a selectable score.
    """
    best_answer = None
    best_score = float("-inf")
    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            continue
        # Strict ">" keeps the first solution among equal scores.
        if score > best_score:
            best_score = score
            best_answer = answer
    return best_answer


def majority_vote(extracted_answers):
    """
    Unweighted majority vote over the extracted answers.

    None entries (solutions with no extractable answer) do not vote.
    Returns the most frequent answer, preferring the one seen first when
    counts are tied, or None if there were no votes at all.
    """
    tally = {}
    for candidate in extracted_answers:
        if candidate is None:
            continue
        tally[candidate] = tally.get(candidate, 0) + 1

    if tally:
        return max(tally, key=lambda candidate: tally[candidate])
    return None


# ──────────────────────────────────────────────────────────────────────────────
# Load scored results
# ──────────────────────────────────────────────────────────────────────────────
print("=" * 70)
print("STEP 3: Computing Best-of-N accuracy with weighted selection")
print("=" * 70)

# Decode explicitly as UTF-8: without ``encoding`` open() falls back to the
# platform's locale encoding, which can mis-read non-ASCII characters.
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json", encoding="utf-8") as f:
    scored_results = json.load(f)

# Also load greedy results for comparison
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json", encoding="utf-8") as f:
    greedy_results = json.load(f)

# ──────────────────────────────────────────────────────────────────────────────
# Compute Best-of-N for each problem
# ──────────────────────────────────────────────────────────────────────────────
# The two result lists are paired by position below. zip() would silently
# drop the unmatched tail of the longer list, leaving the tallies inconsistent
# with the number of scored problems, so fail loudly on a mismatch instead.
if len(scored_results) != len(greedy_results):
    raise ValueError(
        f"scored_results has {len(scored_results)} entries but greedy_results "
        f"has {len(greedy_results)}; the two files must align one-to-one"
    )

# Running tallies of correctly answered problems, one per selection strategy.
weighted_correct = 0
standard_correct = 0
majority_correct = 0
greedy_correct_count = 0

# One summary dict per problem; written to disk after the loop.
results_summary = []

for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)):
    problem_id = scored["unique_id"]
    ground_truth = scored["answer"]

    # Extract answers from sampled solutions
    extracted = scored["extracted_answers"]
    scores = scored["prm_scores"]

    # Weighted Best-of-N. Correctness is an exact match between the selected
    # answer and the ground truth; a missing answer is never correct.
    weighted_answer, answer_scores = weighted_best_of_n(extracted, scores)
    weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth)
    if weighted_is_correct:
        weighted_correct += 1

    # Standard Best-of-N (for comparison)
    standard_answer = standard_best_of_n(extracted, scores)
    standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth)
    if standard_is_correct:
        standard_correct += 1

    # Majority vote (for comparison)
    majority_answer = majority_vote(extracted)
    majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth)
    if majority_is_correct:
        majority_correct += 1

    # Greedy baseline: correctness is read from the greedy results file
    # rather than recomputed here.
    greedy_answer = greedy["greedy_extracted_answer"]
    greedy_is_correct = greedy["greedy_correct"]
    if greedy_is_correct:
        greedy_correct_count += 1

    # Count how many of the N solutions got the right answer
    n_correct_in_sample = sum(1 for a in extracted if a == ground_truth)

    # Summary for this problem
    summary = {
        "idx": i,
        "unique_id": problem_id,
        "level": scored["level"],
        "subject": scored["subject"],
        "ground_truth": ground_truth,
        "greedy_answer": greedy_answer,
        "greedy_correct": greedy_is_correct,
        "weighted_bon_answer": weighted_answer,
        "weighted_bon_correct": weighted_is_correct,
        "standard_bon_answer": standard_answer,
        "standard_bon_correct": standard_is_correct,
        "majority_vote_answer": majority_answer,
        "majority_vote_correct": majority_is_correct,
        "n_correct_in_16": n_correct_in_sample,
        "answer_score_breakdown": answer_scores,
        "prm_scores": scores,
    }
    results_summary.append(summary)

    # Print per-problem results
    status_g = "✓" if greedy_is_correct else "✗"
    status_w = "✓" if weighted_is_correct else "✗"
    print(f"\n  [{problem_id}] Level {scored['level']} | {scored['subject']}")
    print(f"    Ground truth: {ground_truth}")
    print(f"    Greedy {status_g}: {greedy_answer}")
    print(f"    Weighted BoN {status_w}: {weighted_answer}")
    print(f"    Correct in sample: {n_correct_in_sample}/{len(extracted)}")
    if answer_scores:
        # Show answer groups from highest to lowest summed score.
        print(f"    Score breakdown: {dict(sorted(answer_scores.items(), key=lambda x: -x[1]))}")

# ──────────────────────────────────────────────────────────────────────────────
# Overall results
# ──────────────────────────────────────────────────────────────────────────────
n_problems = len(scored_results)


def _format_accuracy(n_correct, n_total):
    """Format a tally as ``correct/total = pct``, or ``n/a`` when total is 0."""
    if n_total == 0:
        # An empty results file would otherwise raise ZeroDivisionError.
        return f"{n_correct}/{n_total} = n/a"
    return f"{n_correct}/{n_total} = {n_correct/n_total:.1%}"


print("\n" + "=" * 70)
print("RESULTS SUMMARY")
print("=" * 70)
print(f"  Greedy (N=1):              {_format_accuracy(greedy_correct_count, n_problems)}")
print(f"  Majority Vote (N=16):      {_format_accuracy(majority_correct, n_problems)}")
print(f"  Standard Best-of-N (N=16): {_format_accuracy(standard_correct, n_problems)}")
print(f"  Weighted Best-of-N (N=16): {_format_accuracy(weighted_correct, n_problems)}")

# Save results
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json", "w", encoding="utf-8") as f:
    json.dump(results_summary, f, indent=2)
print("\nSaved detailed results to outputs/bon_results.json")

# ──────────────────────────────────────────────────────────────────────────────
# Compute Best-of-N at various N values (using the N=16 sample)
# ──────────────────────────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("ANALYSIS: How accuracy varies with N")
print("=" * 70)

import random
random.seed(42)  # fixed seed so the subsampled accuracies are reproducible

n_values = [1, 2, 4, 8, 16]
n_trials = 50  # Average over multiple random subsets when subsampling


def _weighted_accuracy_at_n(results, n):
    """Weighted Best-of-N accuracy over ``results`` using at most ``n`` solutions each.

    For a problem with more than ``n`` solutions, ``n`` of them are drawn at
    random without replacement; otherwise all of its solutions are used and
    no random numbers are consumed. Returns 0.0 for an empty ``results``.
    """
    if not results:
        return 0.0
    correct = 0
    for s in results:
        answers = s["extracted_answers"]
        scores = s["prm_scores"]
        # Size the pool from the data instead of assuming 16 solutions.
        pool_size = min(len(answers), len(scores))
        if n >= pool_size:
            sub_answers, sub_scores = answers, scores
        else:
            # Random subset of N solutions
            indices = random.sample(range(pool_size), n)
            sub_answers = [answers[j] for j in indices]
            sub_scores = [scores[j] for j in indices]
        answer, _ = weighted_best_of_n(sub_answers, sub_scores)
        if answer == s["answer"]:
            correct += 1
    return correct / len(results)


# Largest number of usable solutions any problem has; at or above this N
# nothing is subsampled, so a single deterministic pass is enough.
max_pool_size = max(
    (min(len(s["extracted_answers"]), len(s["prm_scores"])) for s in scored_results),
    default=0,
)

accuracy_by_n = {}
for n in n_values:
    if n >= max_pool_size:
        # Use all solutions
        acc = _weighted_accuracy_at_n(scored_results, n)
    else:
        # Subsample and average over trials
        trial_accs = [
            _weighted_accuracy_at_n(scored_results, n) for _ in range(n_trials)
        ]
        acc = sum(trial_accs) / len(trial_accs)

    accuracy_by_n[n] = acc
    print(f"  N={n:2d}: {acc:.1%}")

# Save accuracy-by-N for plotting (json.dump writes the integer keys as strings)
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json", "w") as f:
    json.dump(accuracy_by_n, f, indent=2)

print("\nDone! Results saved. Run step4_analysis.py for plots.")