cmpatino (HF Staff) committed
Commit 02efe1b · verified · 1 Parent(s): 1ce1b4f

Upload code/step2_sample_and_score.py with huggingface_hub

Files changed (1)
  1. code/step2_sample_and_score.py +212 -0
code/step2_sample_and_score.py ADDED
@@ -0,0 +1,212 @@
"""
Step 2: Sample N=16 solutions per problem and score with the Skywork PRM.

This script:
1. Loads the filtered problems from Step 1.
2. Generates N=16 solutions per problem using temperature sampling.
3. Loads Skywork-o1-Open-PRM and scores each solution (last-step prediction).
4. Saves all solutions + scores for the Best-of-N computation.

The Skywork PRM is loaded using its custom PRM_MODEL class, which wraps
AutoModelForCausalLM with a ValueHead (a linear projection to a scalar).
The model outputs a sigmoid-normalized score in [0, 1] at each step boundary.

Co-authored with Claude (Anthropic). I can explain all code logic.
"""

import json
import os
import subprocess
import sys
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ──────────────────────────────────────────────────────────────────────────────
# Helper functions
# ──────────────────────────────────────────────────────────────────────────────
def extract_boxed_solution(text: str) -> Optional[str]:
    """Extract the content of the last \\boxed{...} in text, or None if absent."""
    try:
        start_index = text.rindex("\\boxed{")
        content_start = start_index + len("\\boxed{")
        bracket_count = 1
        current_pos = content_start
        # Walk forward, tracking brace depth so nested {...} groups stay intact.
        while bracket_count > 0 and current_pos < len(text):
            if text[current_pos] == "{":
                bracket_count += 1
            elif text[current_pos] == "}":
                bracket_count -= 1
            current_pos += 1
        if bracket_count == 0:
            return text[content_start : current_pos - 1].strip()
        return None  # Unbalanced braces: the \boxed{ was never closed
    except ValueError:  # rindex: no "\boxed{" found in text
        return None
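
# Illustrative check of the brace counter (not part of the pipeline):
#   extract_boxed_solution("so x = \\boxed{\\frac{1}{2}}")  ->  "\\frac{1}{2}"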


# ──────────────────────────────────────────────────────────────────────────────
# Load filtered problems
# ──────────────────────────────────────────────────────────────────────────────
print("=" * 70)
print("STEP 2a: Loading problems and generating N=16 solutions per problem")
print("=" * 70)

with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/filtered_problems.json") as f:
    problems_data = json.load(f)
print(f"Loaded {len(problems_data)} problems")

# ──────────────────────────────────────────────────────────────────────────────
# Generate N=16 solutions per problem with temperature sampling
# ──────────────────────────────────────────────────────────────────────────────
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

SYSTEM_PROMPT = (
    "You are a helpful math assistant. Solve the problem step by step, "
    "showing your reasoning clearly. Put your final answer inside "
    "\\boxed{answer} at the end of your solution."
)

N = 16  # Number of solutions per problem
TEMPERATURE = 0.7  # Sampling temperature — balances diversity vs quality
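
# With N = 16 samples per problem, the script makes N * len(problems_data)
# generate calls below, and later the same number of PRM scoring passes.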

all_results = []
for i, p in enumerate(problems_data):
    print(f"\n  Problem {i+1}/{len(problems_data)}: {p['unique_id']} (Level {p['level']})")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": p["problem"]},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
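
    # For Qwen2.5-Instruct, the rendered prompt follows ChatML, roughly
    # (a sketch, not an exact transcript):
    #   <|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n
    #   <|im_start|>user\n{problem}<|im_end|>\n
    #   <|im_start|>assistant\n
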
89
+
90
+ solutions = []
91
+ for j in range(N):
92
+ with torch.no_grad():
93
+ output = model.generate(
94
+ **inputs,
95
+ max_new_tokens=2048,
96
+ do_sample=True,
97
+ temperature=TEMPERATURE,
98
+ top_p=0.95,
99
+ )
100
+ generated = output[0][inputs["input_ids"].shape[1]:]
101
+ solution_text = tokenizer.decode(generated, skip_special_tokens=True)
102
+ solutions.append(solution_text)
103
+
104
+ if (j + 1) % 4 == 0:
105
+ print(f" Generated {j+1}/{N} solutions")
106
+
107
+ result = {**p, "sampled_solutions": solutions}
108
+ all_results.append(result)
109
+

# Save solutions before scoring (in case PRM loading takes time or fails)
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/sampled_solutions.json", "w") as f:
    json.dump(all_results, f, indent=2)
print(f"\nSaved {N} solutions per problem to outputs/sampled_solutions.json")

# Free LLM memory before loading the PRM
del model
torch.cuda.empty_cache()
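# Note: empty_cache() is a safe no-op when CUDA is unavailable; on Apple
# Silicon, torch.mps.empty_cache() would be the analogous call.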
print("Freed LLM memory.")

# ──────────────────────────────────────────────────────────────────────────────
# Score solutions with Skywork PRM
# ──────────────────────────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("STEP 2b: Scoring solutions with Skywork-o1-Open-PRM")
print("=" * 70)

# Clone the Skywork PRM inference repo for the custom model class
PRM_REPO_PATH = "/Users/cmpatino/Projects/ml-intern/exercise/skywork-o1-prm-inference"
if not os.path.exists(PRM_REPO_PATH):
    print("Cloning Skywork PRM inference repo...")
    subprocess.run(
        ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git", PRM_REPO_PATH],
        check=True,
    )
sys.path.insert(0, PRM_REPO_PATH)

# These imports resolve only once PRM_REPO_PATH is on sys.path
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards

PRM_MODEL_ID = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

prm_tokenizer = AutoTokenizer.from_pretrained(PRM_MODEL_ID, trust_remote_code=True)
prm_model = PRM_MODEL.from_pretrained(PRM_MODEL_ID, device_map="auto").eval()

print("PRM model loaded successfully.")


def score_solution(problem: str, solution: str) -> list[float]:
    """
    Score a single solution using the PRM.

    Returns a list of per-step scores (sigmoid-normalized, in [0, 1]).
    The last element is the 'last step prediction' — our final reward.

    The PRM splits the solution into steps at newline ("\\n") boundaries and
    assigns a score at the end of each step. Each score is the model's
    estimate of the probability that the reasoning is correct at that step.
    """
    input_ids, steps, reward_flags = prepare_input(problem, solution, prm_tokenizer, step_token="\n")

    # Prepare a batch of size 1
    input_ids_t, attention_mask_t, reward_flags_t = prepare_batch_input_for_model(
        [input_ids], [reward_flags], prm_tokenizer.pad_token_id
    )

    # Move tensors to the model's device
    device = next(prm_model.parameters()).device
    input_ids_t = input_ids_t.to(device)
    attention_mask_t = attention_mask_t.to(device)
    reward_flags_t = reward_flags_t.to(device)

    with torch.no_grad():
        # return_probs=True applies the sigmoid internally
        _, _, rewards = prm_model(
            input_ids=input_ids_t,
            attention_mask=attention_mask_t,
            return_probs=True,
        )

    step_rewards = derive_step_rewards(rewards, reward_flags_t)
    return step_rewards[0]  # Step scores for the single sample in the batch
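
# Illustrative usage (scores are made up; one per newline-delimited step):
#   score_solution("What is 2+2?", "2 + 2 = 4\nThe answer is \\boxed{4}.")
#   -> [0.91, 0.97]   # the last value would be this solution's final reward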


# Score all solutions
print("\nScoring all solutions...")
for i, result in enumerate(all_results):
    print(f"\n  Scoring problem {i+1}/{len(all_results)}: {result['unique_id']}")
    scores = []
    extracted_answers = []

    for j, solution in enumerate(result["sampled_solutions"]):
        # Get the PRM's per-step scores
        step_scores = score_solution(result["problem"], solution)
        # Use the last step prediction as the final reward (per DeepMind Appendix E)
        final_score = step_scores[-1] if step_scores else 0.0
        scores.append(final_score)

        # Extract the final answer from \boxed{}
        answer = extract_boxed_solution(solution)
        extracted_answers.append(answer)

        if (j + 1) % 4 == 0:
            print(f"    Scored {j+1}/{N} solutions (last score: {final_score:.4f})")

    result["prm_scores"] = scores
    result["extracted_answers"] = extracted_answers

# Save scored results
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json", "w") as f:
    json.dump(all_results, f, indent=2)
print("\nSaved scored results to outputs/scored_results.json")
print("Ready for Step 3 (Best-of-N computation).")