Update README.md
Browse files
README.md
CHANGED
|
@@ -120,111 +120,7 @@ print(step_reward) # [[0.796875, 0.185546875, -0.0625, 0.078125]]
|
|
| 120 |
# torch.tensor(step_reward).sum(dim=-1)
|
| 121 |
```
|
| 122 |
|
| 123 |
-
2.
|
| 124 |
-
|
| 125 |
-
```python
|
| 126 |
-
import numpy as np
|
| 127 |
-
import torch
|
| 128 |
-
from datasets import load_dataset
|
| 129 |
-
from tqdm import tqdm
|
| 130 |
-
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
| 131 |
-
|
| 132 |
-
ds_names = ["GSM8K", "MATH500"]
|
| 133 |
-
ds = [
|
| 134 |
-
load_dataset(
|
| 135 |
-
f"RLHFlow/Deepseek-{ds_name}-Test"
|
| 136 |
-
)['test'] for ds_name in ds_names
|
| 137 |
-
]
|
| 138 |
-
|
| 139 |
-
def make_step_rewards(logits, token_masks):
|
| 140 |
-
all_scores_res = []
|
| 141 |
-
for sample, token_mask in zip(logits, token_masks):
|
| 142 |
-
# sample: (seq_len, num_labels)
|
| 143 |
-
probs = sample[token_mask].softmax(dim=-1) # (num_steps, 2)
|
| 144 |
-
process_reward = probs[:, 1] - probs[:, 0] # (num_steps,)
|
| 145 |
-
# weighted sum that approximates the minimum step reward; highly recommended for BoN evaluation and for fine-tuning LLMs
|
| 146 |
-
weight = torch.softmax(
|
| 147 |
-
-process_reward / 0.1,
|
| 148 |
-
dim=-1,
|
| 149 |
-
)
|
| 150 |
-
process_reward = weight * process_reward
|
| 151 |
-
all_scores_res.append(process_reward.cpu().tolist())
|
| 152 |
-
return all_scores_res
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
model_name = "jinachris/PURE-PRM-7B"
|
| 156 |
-
device = "auto"
|
| 157 |
-
|
| 158 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 159 |
-
model_name,
|
| 160 |
-
trust_remote_code=True,
|
| 161 |
-
)
|
| 162 |
-
model = AutoModelForTokenClassification.from_pretrained(
|
| 163 |
-
model_name,
|
| 164 |
-
device_map=device,
|
| 165 |
-
torch_dtype=torch.bfloat16,
|
| 166 |
-
trust_remote_code=True,
|
| 167 |
-
).eval()
|
| 168 |
-
|
| 169 |
-
step_separator = "\n"
|
| 170 |
-
step_separator_token = tokenizer(
|
| 171 |
-
step_separator,
|
| 172 |
-
add_special_tokens=False,
|
| 173 |
-
return_tensors='pt',
|
| 174 |
-
)['input_ids']
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
for ds_item, ds_name in zip(ds, ds_names):
|
| 178 |
-
# sampled_ids = np.random.choice(range(len(ds_item)), size=100, replace=False)
|
| 179 |
-
correct = 0
|
| 180 |
-
total = 0
|
| 181 |
-
for idx in tqdm(range(len(ds_item)), desc=f"Processing questions in {ds_name}"):
|
| 182 |
-
question = ds_item['prompt'][idx]
|
| 183 |
-
answers = ds_item['answers'][idx]
|
| 184 |
-
labels = ds_item['label'][idx]
|
| 185 |
-
outcome_scores = []
|
| 186 |
-
|
| 187 |
-
question_ids = tokenizer(
|
| 188 |
-
question,
|
| 189 |
-
add_special_tokens=False,
|
| 190 |
-
return_tensors='pt',
|
| 191 |
-
)['input_ids']
|
| 192 |
-
for answer in tqdm(answers, desc="Processing answers"):
|
| 193 |
-
steps = [i.rstrip() for i in answer.split("\n\n")]
|
| 194 |
-
input_ids = question_ids.clone()
|
| 195 |
-
|
| 196 |
-
score_ids = []
|
| 197 |
-
for step in steps:
|
| 198 |
-
step_ids = tokenizer(
|
| 199 |
-
step,
|
| 200 |
-
add_special_tokens=False,
|
| 201 |
-
return_tensors='pt',
|
| 202 |
-
)['input_ids']
|
| 203 |
-
input_ids = torch.cat(
|
| 204 |
-
[input_ids, step_ids, step_separator_token],
|
| 205 |
-
dim=-1,
|
| 206 |
-
)
|
| 207 |
-
score_ids.append(input_ids.size(-1) - 1)
|
| 208 |
-
|
| 209 |
-
input_ids = input_ids.to(model.device, dtype=torch.long)
|
| 210 |
-
token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
|
| 211 |
-
token_masks[0, score_ids] = True
|
| 212 |
-
assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)
|
| 213 |
-
|
| 214 |
-
with torch.no_grad():
|
| 215 |
-
logits = model(input_ids).logits
|
| 216 |
-
step_reward = make_step_rewards(logits, token_masks)
|
| 217 |
-
outcome_reward = torch.tensor(step_reward).sum(dim=-1)
|
| 218 |
-
|
| 219 |
-
# TODO: batch input & output
|
| 220 |
-
outcome_scores.append(outcome_reward.item())
|
| 221 |
-
|
| 222 |
-
best_idx = np.argmax(outcome_scores)
|
| 223 |
-
if labels[best_idx] == 1:
|
| 224 |
-
correct += 1
|
| 225 |
-
total += 1
|
| 226 |
-
print(f"Accuracy on {ds_name}: {correct / total}")
|
| 227 |
-
```
|
| 228 |
|
| 229 |
## Citation
|
| 230 |
|
|
|
|
| 120 |
# torch.tensor(step_reward).sum(dim=-1)
|
| 121 |
```
|
| 122 |
|
| 123 |
+
2. For evaluation using the Best-of-N method, or on ProcessBench and PRMBench, refer to [our GitHub repository](https://github.com/CJReinforce/PURE/tree/verl/PRM/eval).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
## Citation
|
| 126 |
|