ORMs are nominally trained to predict the correctness of the whole solution at the position of `<eos>`. In practice, however, they are trained to forecast the correctness of the whole solution at every token (i.e., with a token-level loss).
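Concretely, letting $y \in \{0,1\}$ be the correctness label of the full solution and $s_t$ the scalar score the model produces at token $t$ of a $T$-token solution, a plausible form of such a token-level objective (an assumption for illustration; the exact loss is not documented here) is binary cross-entropy applied at every position:

$$\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \Bigl[ y \log \sigma(s_t) + (1 - y) \log\bigl(1 - \sigma(s_t)\bigr) \Bigr]$$

Every token is supervised with the same solution-level label, and the prediction at `<eos>` is the one read off at inference time.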
Usage:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-orm-prm800k-level-1to3-hf"

# load the ORM in bfloat16 and let accelerate place it on the available devices
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# the ORM reuses the base Llemma tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")
| qa_example = """# Question | |
| Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$ | |
| # Solution | |
| To convert from rectangular to polar coordinates, I need to use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \tan^{-1}(y/x).$ | |
| In this case, $x = 0$ and $y = 3,$ so I can plug them into the formulas. | |
| For $r,$ I get $r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3.$ | |
| For $\theta,$ I get $\theta = \tan^{-1}(3/0).$ | |
| This is undefined, since the tangent function is not defined at $0.$ | |
| However, I can use the fact that the point $(0,3)$ lies on the positive $y$-axis, which has an angle of $\pi/2$ radians or $90^\circ.$ | |
| Therefore, I can choose any angle in the range $(0,\pi/2)$ as the value of $\theta.$ | |
| I will choose $\theta = \pi/2,$ since it is the simplest and most natural choice. | |
| Therefore, the polar coordinates of the point $(0,3)$ are $(3,\pi/2).$ | |
| # Answer | |
| (3,\pi/2)""" | |
# [1:] drops the spurious leading token the tokenizer prepends to any string,
# so the sequences match occurrences in the middle of the text
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id

input_ids = tokenizer.encode(qa_example)

# scan for every "\n\n" that follows "# Solution", plus the <eos> position
begin_solution_flag = False
candidate_positions = []
for start_idx in range(len(input_ids)):
    # start collecting positions only once the solution section has begun
    if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
        begin_solution_flag = True
    # every "\n\n" inside the solution marks the end of a step
    if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
        candidate_positions.append(start_idx)
    # the <eos> position carries the ORM's final prediction
    if input_ids[start_idx] == eos_token:
        candidate_positions.append(start_idx)
        break
# drop the first and the second-to-last candidate positions: they are the
# "\n\n" immediately after "# Solution" and immediately after "# Answer",
# neither of which ends a reasoning step
del candidate_positions[0]
del candidate_positions[-2]
# move the inputs to the model's device before scoring
input_tensor = torch.tensor([input_ids]).to(model.device)
candidate_positions = torch.tensor(candidate_positions)

with torch.no_grad():
    logits = model(input_tensor).logits

# the reward score is read out as the mean of the logits over the vocabulary
# dimension, then squashed to a probability
scores = logits.mean(dim=-1)
step_scores = scores[0][candidate_positions]
step_probs = torch.sigmoid(step_scores)

print(step_probs)
# only the last probability (at <eos>) is the ORM's output; the earlier ones
# are intermediate token-level predictions
# tensor([0.4531, 0.3882, 0.3748, 0.4785, 0.4087, 0.3166, 0.3040, 0.2295, 0.2628, 0.2568])
```
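A typical downstream use of the ORM's output is best-of-$n$ reranking: sample several candidate solutions for a question, score each one, and keep the highest-scoring candidate. A minimal sketch, assuming the snippet above is wrapped in a hypothetical helper `score_solution(question, solution)` that returns the final `<eos>` probability as a float (both functions below are illustrative, not part of this repository):

```python
def rerank_best_of_n(question: str, candidate_solutions: list[str]) -> str:
    """Return the candidate the ORM scores as most likely to be correct.

    `score_solution` is a hypothetical wrapper around the snippet above
    that returns the last element of `step_probs` as a float.
    """
    scored = [(score_solution(question, solution), solution)
              for solution in candidate_solutions]
    best_prob, best_solution = max(scored, key=lambda pair: pair[0])
    return best_solution
```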