zeju-0727 committed on
Commit
03f925f
·
verified ·
1 Parent(s): bc901fd

Upload omegaprm.py

Browse files
Files changed (1) hide show
  1. omegaprm.py +787 -0
omegaprm.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import heapq
2
+ import math
3
+ import random
4
+ import re
5
+ import json
6
+ from typing import List, Tuple, Dict, Any, Optional
7
+ import itertools
8
+ from transformers import AutoTokenizer
9
+ import asyncio # New import added for async handling
10
+ from openai import AsyncOpenAI # Using AsyncOpenAI as client
11
+ import numpy as np
12
+
13
+
14
+ # Helper function to separate reasoning steps
15
def separate_steps(steps: List[str], mode: str = 'join') -> Any:
    """
    Convert between a list of reasoning steps and a single delimited string.

    Steps are delimited by a blank line ("\\n\\n").

    Parameters:
        steps: A list of step strings for 'join' mode, or one string for 'split' mode.
        mode (str): 'join' to concatenate the steps, 'split' to break a string apart.

    Returns:
        The joined string ('join') or the list of step strings ('split').

    Raises:
        ValueError: If mode is neither 'join' nor 'split'.
        TypeError: If 'steps' has the wrong type for the requested mode.
    """
    sep = "\n\n"

    # The mode is validated first, so a bad mode always wins over a bad 'steps' type.
    if mode not in ('join', 'split'):
        raise ValueError("Mode should be either 'join' or 'split'.")

    if mode == 'split':
        if isinstance(steps, str):
            return steps.split(sep)
        raise TypeError("For 'split' mode, 'steps' must be a string.")

    if isinstance(steps, list):
        return sep.join(steps)
    raise TypeError("For 'join' mode, 'steps' must be a list of strings.")
27
+
28
+
29
+ # def judge_ans(
30
+ # problem_str: str,
31
+ # extracted_groundtruth: str,
32
+ # output_list: List[str],
33
+ # v_list: List[float],
34
+ # aggration_mode: str,
35
+ # extract_answer_fn,
36
+ # judge_correct_fn,
37
+ # normalize=False,
38
+ # ):
39
+ # ans_list = [extract_answer_fn(txt) for txt in output_list]
40
+ # valid_ans_list, valid_v_list = [], []
41
+ # for i, ans in enumerate(ans_list):
42
+ # if ans != INVALID_ANS:
43
+ # valid_ans_list.append(ans)
44
+ # valid_v_list.append(v_list[i])
45
+ # if len(valid_ans_list) == 0:
46
+ # return 0
47
+
48
+ # if "orm" in aggration_mode and normalize:
49
+ # # score_normalization: this is only necessary for [-1, 1] values
50
+ # valid_v_list = np.array(valid_v_list)
51
+ # valid_v_list -= valid_v_list.min()
52
+ # valid_v_list /= valid_v_list.max() + 1e-3
53
+ # valid_v_list = valid_v_list.tolist()
54
+ # aggregated_ans = AGG_FN_MAP[aggration_mode](valid_ans_list, valid_v_list)
55
+
56
+ # return (
57
+ # 1 if judge_correct_fn(problem_str, extracted_groundtruth, aggregated_ans) else 0
58
+ # )
59
+
60
+
61
+
62
+
63
+
64
+ # Helper function to check correctness of a generated response
65
def check_correctness(generated_response: str, expected_answer: str) -> bool:
    """
    Judge whether a generated response contains the expected final answer.

    The answer is extracted from the response with `MATH.extract_answer` and
    compared against `expected_answer` with `MATH.judge_correct` (called with an
    empty problem string).

    Parameters:
        generated_response (str): The complete generated solution text.
        expected_answer (str): The ground-truth answer.

    Returns:
        bool: True when the extracted answer is judged correct, False otherwise.
    """
    # NOTE(review): `MATH` is neither imported nor defined in the visible part of
    # this file; it must be brought into scope before this function is called,
    # otherwise a NameError is raised here.
    extracted_answer = MATH.extract_answer(generated_response)
    # Return a real bool so the value matches the annotation (True == 1 and
    # False == 0, so callers that used the result as an int keep working).
    return bool(MATH.judge_correct("", expected_answer, extracted_answer))
77
+
78
+
79
class ProcessRewardModel:
    """
    ProcessRewardModel encapsulates the reward inference process.

    It utilizes a chat-based reward model (e.g., 'Llama3.1-8B-PRM-Mistral-Data')
    to evaluate a sequence of reasoning steps. It sends the steps to the model one
    at a time inside a growing chat history and stops at the first step that does
    not receive a positive judgement ('+').

    When the model name contains 'deepseek' (case-insensitive), consecutive steps
    are first merged in groups of `_MERGE_GROUP_SIZE` to reduce the number of
    requests; per-step score lists then refer to the merged groups, not to the
    original steps.
    """

    # Number of consecutive steps folded into one message for 'deepseek' models.
    _MERGE_GROUP_SIZE = 6

    def __init__(self, client, model="Llama3.1-8B-PRM-Mistral-Data", temperature=0.0, max_tokens=1):
        """
        Initialize the ProcessRewardModel.

        Parameters:
            client: The chat client instance that provides an awaitable
                `chat.completions.create` method.
            model (str): The model name to be used for generating reward completions.
            temperature (float): Sampling temperature. Default is 0.0 for deterministic outcomes.
            max_tokens (int): Maximum tokens to generate in the reward inference. Default is 1.
        """
        self.client = client
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    def evaluate(self, problem: str, steps: list, output_type: str = 'bool') -> Any:
        """
        Synchronously evaluate the process reward.

        This wraps `_async_evaluate` in `asyncio.run`, so it must not be called
        from code that is already running inside an event loop.

        Parameters:
            problem (str): The problem or question statement.
            steps (List[str]): A list of reasoning steps.
            output_type (str): 'bool' for a single verdict; any other value for
                a per-step score list.

        Returns:
            bool: True if all steps are positively judged, False otherwise
            (when output_type == 'bool').
            List[float]: 1.0 for every step before the first rejected one and
            0.0 from there on (for any other output_type).
        """
        return asyncio.run(self._async_evaluate(problem, steps, output_type))

    def _prepare_steps(self, steps: list) -> list:
        """Return the steps to send, merged in fixed-size groups for 'deepseek' models."""
        if 'deepseek' not in self.model.lower():
            return steps
        steps = list(steps)
        size = self._MERGE_GROUP_SIZE
        return ["\n\n".join(steps[start:start + size]) for start in range(0, len(steps), size)]

    async def _async_evaluate(self, problem: str, steps: list, output_type: str = 'bool') -> Any:
        """
        Judge the steps one by one with a short-answer reward model.

        Each step is appended as a user message (the first one prefixed with the
        problem); the model's reply must start with '+' for the step to pass, in
        which case a '+' assistant turn is appended before the next step.

        Parameters:
            problem (str): The problem or question statement.
            steps (List[str]): A list of reasoning steps.
            output_type (str): 'bool' for a single verdict; any other value for
                a per-step score list.

        Returns:
            bool or List[float]: see `evaluate`.
        """
        steps = self._prepare_steps(steps)
        messages = []

        for sdx, step in enumerate(steps):
            content = f"{problem}\n\n{step}" if sdx == 0 else step
            messages.append({'role': 'user', 'content': content})

            completion = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                n=1,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            # `content` may be None (e.g. an empty completion); treat it as a rejection
            # instead of crashing on `.strip()`.
            response = (completion.choices[0].message.content or "").strip().lower()
            if not response.startswith('+'):
                if output_type == 'bool':
                    return False
                return [1.0 if i < sdx else 0.0 for i in range(len(steps))]
            messages.append({'role': 'assistant', 'content': '+'})

        if output_type == 'bool':
            return True
        return [1.0] * len(steps)

    async def _async_evaluate_system2(self, problem: str, steps: list, output_type: str = 'bool') -> Any:
        """
        Judge the steps one by one with a model that reasons before answering.

        The model is asked to finish its response with '+' or '-'. A step passes
        when the last five characters or the last three words of the lower-cased
        response contain '+' and no '-'.

        Parameters:
            problem (str): The problem or question statement.
            steps (List[str]): A list of reasoning steps.
            output_type (str): Accepted for signature parity with
                `_async_evaluate` but not used by this method.

        Returns:
            List[float]: 1.0 for every step before the first rejected one and
            0.0 from there on. Unlike `_async_evaluate`, a list is always returned.
        """
        steps = self._prepare_steps(steps)
        messages = []

        for sdx, step in enumerate(steps):
            if sdx == 0:
                prompt = (
                    f"Problem: {problem}\n\nStep: {step}\n\nIs this step correct? "
                    "You must answer with '+' for correct or '-' for incorrect in the end of your response."
                )
            else:
                # The original follow-up prompt read "or -' for incorrect" (missing
                # opening quote); it now matches the first-step prompt.
                prompt = (
                    f"Step: {step}\n\nIs this step correct? "
                    "You must answer with '+' for correct or '-' for incorrect in the end of your response."
                )
            messages.append({'role': 'user', 'content': prompt})

            completion = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                n=1,
                temperature=self.temperature,
                # Fixed generous limit (self.max_tokens is not used here) because
                # the verdict comes at the end of a longer response.
                max_tokens=8192,
            )
            response = completion.choices[0].message.content

            # Only the tail of the response is inspected for the verdict.
            content = (response or "").strip().lower()
            last_words = ' '.join(content.split()[-3:])  # Last 3 words

            judgment = any(
                '+' in part and '-' not in part
                for part in (content[-5:], last_words)
            )

            if not judgment:
                return [1.0 if i < sdx else 0.0 for i in range(len(steps))]
            # Record a positive turn with an empty reasoning block instead of
            # replaying the model's full response.
            messages.append({'role': 'assistant', 'content': '<think>\n\n</think> +'})

        return [1.0] * len(steps)
223
+
224
+
225
+
226
class LanguageModel:
    """
    Thin wrapper around an async OpenAI-compatible client used to sample rollouts.

    It also tries to load a Hugging Face tokenizer for the model so that callers
    can apply the model's chat template; `tokenizer` is None when loading fails.
    """

    def __init__(self, client, model_name="/root/.cache/modelscope/hub/Qwen/Qwen2___5-Math-7B-Instruct",
                 max_new_tokens=512, temperature=0.7, top_p=0.9):
        """
        Initialize the LanguageModel for async OpenAI calls.

        Parameters:
        - client: An instance of AsyncOpenAI passed externally.
        - model_name (str): API model name to use. Also used to locate the
          tokenizer on disk (see `_resolve_tokenizer_path`).
        - max_new_tokens (int): Maximum tokens for generation.
        - temperature (float): Sampling temperature.
        - top_p (float): Nucleus sampling probability.
        """
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.default_prompt = (
            "Please complete the answer for the question based on the given steps without generating existing steps again, "
            "and separate your following steps using \n\n.\n\n"
        )
        # Retain tokenizer for chat template operations elsewhere. Loading is
        # best-effort: on any failure the error is printed and tokenizer is None.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self._resolve_tokenizer_path(model_name))
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            self.tokenizer = None
        # Use the external AsyncOpenAI client.
        self.async_client = client

    @staticmethod
    def _resolve_tokenizer_path(model_name: str) -> str:
        """
        Return the directory the tokenizer is loaded from.

        Relative names are looked up under /root/ (the original behaviour). An
        absolute path such as the default `model_name` is used as-is; previously
        it was also prefixed, producing '/root//root/...'.
        """
        if model_name.startswith("/"):
            return model_name
        return f"/root/{model_name}"

    async def generate_rollout(self, state_prefix: str, num_copies: int) -> List[str]:
        """
        Asynchronously sample continuations of a prefix via the completions endpoint.

        Parameters:
        - state_prefix (str): The current solution prefix, sent verbatim as the prompt.
        - num_copies (int): The number of response copies to request.

        Returns:
        - List[str]: The text of every choice in the response.
        """
        response = await self.async_client.completions.create(
            model=self.model_name,
            prompt=state_prefix,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            n=num_copies,
        )
        return [choice.text for choice in response.choices]

    def update_prompt(self, new_prompt: str):
        """
        Update the default prompt if necessary.

        Parameters:
        - new_prompt (str): The new prompt template.
        """
        self.default_prompt = new_prompt

    def evaluate_correctness(self, response: str, expected_answer: str) -> bool:
        """
        Check if the generated solution matches the expected answer.

        Delegates to the module-level `check_correctness`.

        Parameters:
        - response (str): The complete generated response.
        - expected_answer (str): The expected answer to compare with.

        Returns:
        - The result of `check_correctness(response, expected_answer)`.
        """
        return check_correctness(response, expected_answer)
299
+
300
+
301
+ # Define the State class
302
class State:
    """
    A node of the search tree: one solution prefix plus its rollout statistics.
    """

    def __init__(self, solution_prefix: str, parent: Optional['State'] = None):
        # Tree linkage and text.
        self.parent = parent  # Parent node, None for the root
        self.solution_prefix = solution_prefix  # Solution prefix held by this node
        self.children: List['State'] = []  # Child nodes

        # Rollout bookkeeping.
        self.R: List[str] = []  # Every rollout generated from this state
        self.incorrect_rollouts: List[str] = []  # Distinct incorrect rollouts
        self.Q: Dict[str, float] = {}  # Q(s, r): estimated value for each rollout

        # Statistics.
        self.N = 0  # Visit count (number of times selected)
        self.total_rollouts = 0  # Number of rollouts generated from this state
        self.correct_rollouts = 0  # Number of those judged correct
        self.MC: Optional[float] = None  # Monte Carlo estimation (c/k), None until estimated

    def add_rollout(self, rollout: str):
        """Record a rollout generated from this state (duplicates are kept)."""
        self.R.append(rollout)

    def add_incorrect_rollout(self, rollout: str):
        """Record an incorrect rollout, ignoring one that is already recorded."""
        if rollout in self.incorrect_rollouts:
            return
        self.incorrect_rollouts.append(rollout)

    def get_full_solution(self) -> str:
        """
        Join the solution_prefix of every node from the root down to this one,
        separated by blank lines.
        """
        # NOTE(review): the tree-building code visible in this file stores
        # cumulative prefixes in child nodes, in which case this concatenation
        # repeats ancestor text -- confirm the intended use before relying on it.
        if not self.parent:
            return self.solution_prefix
        ancestors_text = self.parent.get_full_solution()
        return ancestors_text + '\n\n' + self.solution_prefix

    def get_new_text(self) -> str:
        """
        Return the text this node adds on top of its parent's prefix, stripped.

        The parent's prefix length is used as the cut-off; the root returns its
        whole (stripped) prefix.
        """
        if not self.parent:
            # Root node (the question)
            return self.solution_prefix.strip()
        cut = len(self.parent.solution_prefix)
        return self.solution_prefix[cut:].strip()

    def get_text_with_labels(self) -> Dict[str, Any]:
        """
        Return this subtree as nested dictionaries with the keys:
        - 'text': The new text at this node.
        - 'mc_value': The MC value at this node.
        - 'children': A list of child nodes with the same structure.
        """
        labelled_children = []
        for child in self.children:
            labelled_children.append(child.get_text_with_labels())
        return {
            'text': self.get_new_text(),
            'mc_value': self.MC,
            'children': labelled_children,
        }
354
+
355
+
356
+ # Define the Search Tree class
357
class SearchTree:
    """
    Container for the search tree: a root reference plus a flat list of all states.
    """

    def __init__(self):
        self.nodes: List[State] = []  # Every state registered so far, in insertion order
        self.root: Optional[State] = None  # Set by the caller; None until then

    def add_state(self, state: State):
        """Register a state in the flat node list (the root is not set here)."""
        self.nodes.append(state)
364
+
365
+ # Define the Candidate Pool as a priority queue with update capability
366
class CandidatePool:
    """
    Max-priority queue of (state, rollout) candidates with priority updates.

    Updates use lazy deletion: an updated candidate gets a fresh id and a fresh
    heap entry, and the superseded entry is skipped when it is eventually popped.
    """

    def __init__(self):
        self.counter = itertools.count()  # Source of unique, increasing entry ids
        self.heap: List[Tuple[float, int]] = []  # Heap of (-priority, unique_id)
        self.entry_finder: Dict[int, Tuple[float, int]] = {}  # Live entries: unique_id -> (-priority, unique_id)
        self.id_to_rollout: Dict[int, Tuple['State', str]] = {}  # unique_id -> (state, rollout)
        self.latest_id_per_rollout: Dict[Tuple[int, str], int] = {}  # (id(state), rollout) -> newest unique_id

    def add_or_update(self, state: 'State', rollout: str, priority: float):
        """
        Add a new rollout or update the priority of an existing rollout.

        Parameters:
        - state (State): The state associated with the rollout.
        - rollout (str): The rollout string.
        - priority (float): The new priority score.
        """
        # Candidates are keyed by the identity of the state object plus the rollout text.
        key = (id(state), rollout)

        # Invalidate the previous entry for this candidate, if it is still live.
        superseded = self.latest_id_per_rollout.get(key)
        if superseded is not None and superseded in self.entry_finder:
            del self.entry_finder[superseded]
            del self.id_to_rollout[superseded]

        ticket = next(self.counter)
        entry = (-priority, ticket)  # Negated priority turns heapq's min-heap into a max-heap
        self.latest_id_per_rollout[key] = ticket
        heapq.heappush(self.heap, entry)
        self.entry_finder[ticket] = entry
        self.id_to_rollout[ticket] = (state, rollout)

    def pop(self) -> Tuple[Optional['State'], Optional[str]]:
        """
        Pop the rollout with the highest priority.

        Returns:
        - Tuple[Optional[State], Optional[str]]: The state and rollout string, or (None, None) if empty.
        """
        while self.heap:
            _, ticket = heapq.heappop(self.heap)
            if ticket not in self.entry_finder:
                # Superseded entry left behind by an update; skip it.
                continue

            state, rollout = self.id_to_rollout.pop(ticket)
            del self.entry_finder[ticket]

            key = (id(state), rollout)
            if self.latest_id_per_rollout.get(key) == ticket:
                del self.latest_id_per_rollout[key]
            return state, rollout

        return None, None

    def is_empty(self) -> bool:
        """Return True when no live entry remains (stale heap entries do not count)."""
        return len(self.entry_finder) == 0
429
+
430
+ # Define the OmegaPRM algorithm
431
class OmegaPRM:
    """
    OmegaPRM search: builds a tree of solution prefixes labelled with Monte Carlo
    correctness estimates, using binary search over the steps of incorrect
    rollouts to locate where they go wrong.
    """

    def __init__(self, LM: LanguageModel, reward_model, c_puct: float, alpha: float, beta: float, L: int, k: int, N: int,
                 rollout_budget: int, save_data_tree: bool):
        """
        Initialize the OmegaPRM algorithm.

        Parameters:
            LM (LanguageModel): The language model instance.
            reward_model: An instance of ProcessRewardModel to evaluate solution correctness.
            c_puct (float): Exploration constant.
            alpha (float): Weight for MC(s).
            beta (float): Length penalty.
            L (int): Maximum solution length.
            k (int): Number of rollouts for Monte Carlo estimation.
            N (int): Maximum search count.
            rollout_budget (int): Total rollout budget.
            save_data_tree (bool): Whether to save and return the data tree.
        """
        self.LM = LM
        self.reward_model = reward_model
        self.expected_answer = None
        self.c_puct = c_puct
        self.alpha = alpha
        self.beta = beta
        self.L = L
        self.k = k
        self.N = N
        self.rollout_budget = rollout_budget
        self.save_data_tree = save_data_tree

        self.T = SearchTree()
        self.C = CandidatePool()

        self.n = 0
        self.total_rollouts = 0
        # Defined up front so they exist before the first run()/reset().
        self.problem = None
        self.collected_data = []

    def reset(self):
        """Reset internal state variables to prepare for a fresh run."""
        self.expected_answer = None
        self.T = SearchTree()  # Reset search tree
        self.C = CandidatePool()  # Reset candidate pool
        self.n = 0
        self.total_rollouts = 0
        self.collected_data = []  # Clear collected data

    @staticmethod
    def _extend_prefix(prefix: str, rollout: str) -> str:
        """Append a rollout to a prefix with a blank line; an empty prefix yields the rollout unchanged."""
        return (prefix + '\n\n' + rollout).strip() if prefix else rollout

    async def monte_carlo_estimation(self, state: State):
        """
        Perform Monte Carlo estimation for state by generating k rollouts
        and computing MC(s) = c / n, where c is the number of rollouts the reward
        model accepts and n the number of rollouts actually returned.

        Side effects: updates the state's statistics, adds child states for
        correct rollouts (or for incorrect ones when MC(s) == 0), and pushes
        incorrect rollouts into the candidate pool when 0 < MC(s) < 1.
        """
        batch_rollouts = await self.LM.generate_rollout(state.solution_prefix, self.k)

        # Increment visit count of selected state
        state.N += 1

        correct_rollouts = []
        incorrect_rollouts = []
        for rollout in batch_rollouts:
            # Count against the global rollout budget.
            self.total_rollouts += 1
            state.add_rollout(rollout)

            # The reward model judges the whole text (prefix + rollout) step by
            # step; it returns True only when every step is judged positively.
            full_solution = self._extend_prefix(state.solution_prefix, rollout)
            steps = separate_steps(full_solution, mode='split')
            is_correct = await self.reward_model._async_evaluate(self.problem, steps)

            if is_correct:
                correct_rollouts.append(rollout)
            else:
                incorrect_rollouts.append(rollout)
                state.add_incorrect_rollout(rollout)  # Track incorrect rollouts

        # Use the number of rollouts actually received: the backend may return
        # fewer than k choices, and counting k would bias MC downwards.
        state.total_rollouts += len(batch_rollouts)
        state.correct_rollouts += len(correct_rollouts)
        state.MC = state.correct_rollouts / state.total_rollouts if state.total_rollouts > 0 else 0

        if state.MC == 0.0:
            # State is incorrect; record the rollouts but do not explore further.
            for rollout in incorrect_rollouts:
                self.add_incorrect_rollout_to_tree(state, rollout)
            return

        # MC(s) > 0: every correct rollout becomes a child state.
        for rollout in correct_rollouts:
            self.add_correct_rollout_to_tree(state, rollout)

        if state.MC != 1.0:
            # 0 < MC(s) < 1.0: incorrect rollouts become expansion candidates.
            for rollout in incorrect_rollouts:
                priority = self.compute_selection_score(state, rollout)
                self.C.add_or_update(state, rollout, priority)

    async def run(self, question: str, answer: str) -> List:
        """
        Execute the OmegaPRM algorithm.

        Parameters:
        - question (str): The question to generate solutions for.
        - answer (str): The expected answer; stored on the instance as
          `expected_answer`.

        Returns:
        - Collected data: the nested tree dictionary when `save_data_tree` is
          set, otherwise a list of {"solution_prefix", "mc_value"} dictionaries.
        """
        self.reset()
        self.problem = question  # Store the original question for reward evaluation

        print(f"Running OmegaPRM for question: '{question}'\n")
        # Initialization: wrap the question in the chat template when a tokenizer is available.
        if self.LM.tokenizer is not None:
            question_templated = self.LM.tokenizer.apply_chat_template(
                [{"role": "user", "content": question}],
                tokenize=False,
                add_special_tokens=False,
                add_generation_prompt=True
            )
        else:
            question_templated = question
        initial_state = State(solution_prefix=question_templated, parent=None)
        self.expected_answer = answer
        self.T.root = initial_state
        self.T.add_state(initial_state)
        self.n = 0

        # Monte Carlo Estimation for initial_state
        await self.monte_carlo_estimation(initial_state)

        # Main loop
        while self.n < self.N and self.total_rollouts < self.rollout_budget and not self.C.is_empty():
            # Selection Phase
            selected_state, selected_rollout = self.selection_phase()
            if selected_state is None or selected_rollout is None:
                break

            await self.expansion_phase_binary_search(selected_state, selected_rollout)

            # Maintenance Phase
            self.maintenance_phase(selected_state)

            # Increment search count
            self.n += 1

        if self.save_data_tree:
            data = self.collect_tree_structure()
        else:
            data = self.collect_solution_prefixes()
        return data

    def compute_Q(self, state: State, rollout: str) -> float:
        """
        Compute Q(s, r) = alpha^{1 - MC(s)} * beta^{len(r)/L}, where len(r) is based on word count.

        `state.MC` must already be set (it is None before estimation).
        """
        word_count = len(rollout.split())
        length_penalty = word_count / self.L
        return (self.alpha ** (1 - state.MC)) * (self.beta ** length_penalty)

    def compute_U(self, state: State) -> float:
        """
        Compute U(s) = c_puct * sqrt(sum_{s'} N(s')) / (1 + N(s))
        """
        N_total = sum(s.N for s in self.T.nodes)
        if N_total == 0:
            N_total = 1  # Keep the exploration term non-zero before any visit
        return self.c_puct * math.sqrt(N_total) / (1 + state.N)

    def compute_selection_score(self, state: State, rollout: str) -> float:
        """
        Compute selection score: Score(s, r) = Q(s, r) + U(s)
        """
        return self.compute_Q(state, rollout) + self.compute_U(state)

    def selection_phase(self) -> Tuple[Optional[State], Optional[str]]:
        """
        Select (state, rollout) with the highest score from candidate pool C.

        Returns (None, None) when the pool is empty.
        """
        return self.C.pop()

    def _add_labelled_child(self, parent_state: State, rollout: str, mc_value: float):
        """Create a child of parent_state for prefix + rollout with a fixed MC label."""
        new_state = State(
            solution_prefix=self._extend_prefix(parent_state.solution_prefix, rollout),
            parent=parent_state,
        )
        new_state.MC = mc_value
        new_state.total_rollouts = 0
        new_state.correct_rollouts = 0
        self.T.add_state(new_state)
        parent_state.children.append(new_state)  # Add to parent's children

    def add_correct_rollout_to_tree(self, parent_state: State, rollout: str):
        """
        Add the correct rollout to the tree as a child of parent_state (MC = 1.0).
        """
        self._add_labelled_child(parent_state, rollout, 1.0)

    def add_incorrect_rollout_to_tree(self, parent_state: State, rollout: str):
        """
        Add the incorrect rollout to the tree as a child of parent_state (MC = 0.0).

        Parameters:
        - parent_state (State): The state from which the rollout was selected.
        - rollout (str): The incorrect rollout string.
        """
        self._add_labelled_child(parent_state, rollout, 0.0)

    async def binary_search_incorrect_step(self, s_ast: State, steps: List[str], left: int, right: int):
        """
        Recursively perform binary search to find all incorrect steps in the rollout.

        Each call creates a child of `s_ast` holding steps[left..mid] appended to
        its prefix and estimates that child's MC. If the estimate is 0 the search
        continues in the left half from `s_ast`; otherwise it continues in the
        right half from the new child.
        """
        if left > right:
            return

        mid = (left + right) // 2
        new_steps = steps[left:mid + 1]
        if new_steps:
            prefix_solution = s_ast.solution_prefix + '\n\n' + separate_steps(new_steps, mode='join')
        else:
            prefix_solution = s_ast.solution_prefix
        # Create new state s_new
        s_new = State(solution_prefix=prefix_solution.strip(), parent=s_ast)
        self.T.add_state(s_new)
        s_ast.children.append(s_new)

        # Perform Monte Carlo estimation for s_new
        await self.monte_carlo_estimation(s_new)

        if s_new.MC == 0:
            # Found incorrect step; continue searching in the left half to find earlier incorrect steps
            await self.binary_search_incorrect_step(s_ast, steps, left, mid - 1)
        else:
            # Steps up to mid are correct; continue searching in the right half
            await self.binary_search_incorrect_step(s_new, steps, mid + 1, right)

    async def expansion_phase_binary_search(self, parent_state: State, rollout: str):
        """
        Expansion phase: split the selected rollout into steps and binary-search
        them, adding a new state (with its own Monte Carlo estimate) per probe.
        """
        steps = separate_steps(rollout, mode='split')
        await self.binary_search_incorrect_step(parent_state, steps, 0, len(steps) - 1)

    def maintenance_phase(self, state: State):
        """
        Update statistics and candidate pool for all incorrect rollouts associated with the state.

        Parameters:
        - state (State): The state whose incorrect rollouts need to be updated.
        """
        # These rollouts are already known to be incorrect, so correctness is not
        # re-evaluated; only the score is recomputed from the current statistics.
        for rollout in state.incorrect_rollouts:
            priority = self.compute_selection_score(state, rollout)
            self.C.add_or_update(state, rollout, priority)

    def collect_solution_prefixes(self) -> List[Dict[str, Any]]:
        """
        Collect all solution prefixes and their corresponding MC values from the search tree.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing solution prefixes and their MC values.
        """
        return [
            {"solution_prefix": node.solution_prefix, "mc_value": node.MC}
            for node in self.T.nodes
        ]

    def collect_tree_structure(self) -> Dict[str, Any]:
        """
        Collect the tree structure starting from the root.

        Returns:
            Dict[str, Any]: A nested dictionary representing the tree structure,
            or an empty dictionary when there is no root.
        """
        if self.T.root:
            return self.T.root.get_text_with_labels()
        return {}
737
+
738
+
739
+ # Example usage
740
if __name__ == "__main__":
    # Example run: sample rollouts from one served model, judge them with a
    # second served reward model, and dump the resulting search data to JSON.
    from openai import AsyncOpenAI

    # Client for the generation model.
    lm_client = AsyncOpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",
    )
    LM = LanguageModel(
        client=lm_client,
        model_name="DeepSeek-R1-Distill-Qwen-14B",
        max_new_tokens=4096,
        temperature=0.7,
        top_p=0.9,
    )

    # Problem to solve and its ground-truth answer.
    question = "Melinda will roll two standard six-sided dice and make a two-digit number with the two numbers she rolls. For example, if she rolls a 6 and a 3, she can either form 36 or 63. What is the probability that she will be able to make an integer between 10 and 20, inclusive? Express your answer as a common fraction."
    expected_answer = "\\frac{11}{36}"

    # Client for the reward model. It has to be an async client, because the
    # reward model awaits `chat.completions.create`.
    client = AsyncOpenAI(
        base_url="http://localhost:8001/v1",
        api_key="token-abc123",
    )
    reward_model = ProcessRewardModel(
        client,
        model="deepseek-14b-prm-filtered-balance-full",
        temperature=0.0,
        max_tokens=1,
    )

    # Search hyper-parameters.
    search_settings = dict(
        c_puct=0.125,
        alpha=0.5,
        beta=0.9,
        L=500,
        k=8,
        N=10,
        rollout_budget=20,
        save_data_tree=True,
    )
    omega_prm = OmegaPRM(LM=LM, reward_model=reward_model, **search_settings)

    # Run the search to completion on a fresh event loop.
    collected_data = asyncio.run(omega_prm.run(question, expected_answer))

    # Save the collected solutions to a JSON file
    with open("collected_solutions2.json", "w") as f:
        json.dump(collected_data, f, indent=4)
786
+
787
+