Create README.md
README.md
ADDED
@@ -0,0 +1,82 @@
ORMs are usually described as predicting the correctness of the whole solution at the position of "\<eos\>".
In practice, however, they are trained to forecast the correctness of the whole solution at every token (i.e., with a token-level loss).

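To make the token-level loss concrete, here is a minimal sketch. It is not the authors' training code. It assumes each solution token emits one scalar logit and that every token is trained against the same binary label, namely whether the whole solution is correct. The function name `token_level_orm_loss` and the toy logits are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def token_level_orm_loss(token_logits: list[float], solution_is_correct: bool) -> float:
    """Mean binary cross-entropy over all tokens.

    Every token shares one target: the correctness of the whole
    solution (1.0 if correct, 0.0 otherwise).
    """
    target = 1.0 if solution_is_correct else 0.0
    losses = []
    for logit in token_logits:
        p = sigmoid(logit)
        losses.append(-(target * math.log(p) + (1.0 - target) * math.log(1.0 - p)))
    return sum(losses) / len(losses)


# Toy logits for a four-token solution (hypothetical values).
logits = [0.2, -0.1, 0.5, 1.3]
loss_if_correct = token_level_orm_loss(logits, solution_is_correct=True)
loss_if_wrong = token_level_orm_loss(logits, solution_is_correct=False)
print(loss_if_correct, loss_if_wrong)
```

Under this kind of objective, the prediction at any position, including the final one, estimates the same quantity: whether the finished solution will be correct.
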
Usage:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-orm-prm800k-level-1to3-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

# Raw string, so LaTeX backslashes (e.g. \theta, \tan) are not read as escape sequences.
qa_example = r"""# Question

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

# Solution

To convert from rectangular to polar coordinates, I need to use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \tan^{-1}(y/x).$

In this case, $x = 0$ and $y = 3,$ so I can plug them into the formulas.

For $r,$ I get $r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3.$

For $\theta,$ I get $\theta = \tan^{-1}(3/0).$

This is undefined, since the tangent function is not defined at $0.$

However, I can use the fact that the point $(0,3)$ lies on the positive $y$-axis, which has an angle of $\pi/2$ radians or $90^\circ.$

Therefore, I can choose any angle in the range $(0,\pi/2)$ as the value of $\theta.$

I will choose $\theta = \pi/2,$ since it is the simplest and most natural choice.

Therefore, the polar coordinates of the point $(0,3)$ are $(3,\pi/2).$

# Answer

(3,\pi/2)"""

# Token patterns that mark the start of the solution and the end of each step.
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id

input_ids = tokenizer.encode(qa_example)

begin_solution_flag = False

candidate_positions = []

# Collect every "\n\n" position from the start of the solution onwards.
for start_idx in range(len(input_ids)):
    if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
        begin_solution_flag = True

    if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
        candidate_positions.append(start_idx)

    # Note: this only fires if the tokenizer is configured to append <eos> when encoding.
    if input_ids[start_idx] == eos_token:
        candidate_positions.append(start_idx)
        break

# maybe delete the first and the second-to-last candidate_positions
# because they are the "\n\n" after "# Solution" and after "# Answer"
del candidate_positions[0]
del candidate_positions[-2]

input_tensor = torch.tensor([input_ids]).to(model.device)
candidate_positions = torch.tensor(candidate_positions)

with torch.no_grad():
    logits = model(input_tensor).logits
    # The score at each position is the mean of the logits over the vocabulary dimension.
    scores = logits.mean(dim=-1)
    step_scores = scores[0][candidate_positions]
    step_probs = torch.sigmoid(step_scores)

print(step_probs)

# Only the last probability is the ORM's output.
# Example output (exact values may vary):
# tensor([0.4531, 0.3882, 0.3748, 0.4785, 0.4087, 0.3166, 0.3040, 0.2295, 0.2628, 0.2568])
```
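
Since only the last value is the ORM's output, a small helper can separate it from the earlier per-position forecasts. The following is a minimal sketch on a plain Python list; the values are copied from the example output above, and the function name `split_orm_output` is hypothetical.

```python
# Values copied from the example output above.
step_probs = [0.4531, 0.3882, 0.3748, 0.4785, 0.4087,
              0.3166, 0.3040, 0.2295, 0.2628, 0.2568]


def split_orm_output(probs: list[float]) -> tuple[list[float], float]:
    """Return (intermediate forecasts, final ORM score)."""
    if not probs:
        raise ValueError("expected at least one score")
    return probs[:-1], probs[-1]


intermediate, orm_score = split_orm_output(step_probs)
print(f"ORM score: {orm_score:.4f}")
print(f"intermediate forecasts: {len(intermediate)}")
```

With the tensor produced by the usage snippet, `step_probs[-1].item()` should give the same final score as a Python float.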