# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Callable, Optional
import torch
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import _default_compute_score
async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):
    """Score a single completion on ``executor``, bounded by ``timeout`` seconds.

    The scoring call is ``evaluation_func(task, completion, reference, task_extra_info)``
    and runs synchronously on one of the executor's workers.

    Args:
        evaluation_func: Synchronous scoring callable. When ``executor`` is a
            process pool, it and its arguments must be picklable.
        completion: The decoded response being scored.
        reference: Ground-truth value forwarded to ``evaluation_func``.
        task: Data-source identifier, forwarded as the first positional argument.
        task_extra_info: Per-row extra information (may be ``None``).
        executor: A ``concurrent.futures`` executor that runs the call.
        timeout: Seconds to wait for the result.

    Returns:
        ``[score]``, a one-element list holding whatever ``evaluation_func``
        returned, or ``None`` when the call timed out or raised.
    """
    event_loop = asyncio.get_running_loop()
    try:
        scoring_call = partial(evaluation_func, task, completion, reference, task_extra_info)
        pending = event_loop.run_in_executor(executor, scoring_call)
        outcome = await asyncio.wait_for(pending, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Timeout occurred for completion: {completion}")
        return None  # timed-out rows yield None; the caller turns that into a 0.0 score
    except Exception as e:
        print(f"Error processing completion: {completion[:10]}, Error: {e}")
        return None  # failed rows yield None as well
    # Callers index result[0], so keep the one-element list shape.
    return [outcome]
def _kill_pool_workers(executor):
    """Best-effort kill of every worker process currently owned by ``executor``."""
    # ``_processes`` is a private CPython attribute (pid -> Process). It is set to
    # None by shutdown() and may be mutated by the pool's manager thread, so take
    # a snapshot instead of iterating the live dict.
    workers = getattr(executor, "_processes", None) or {}
    for proc in list(workers.values()):
        try:
            proc.kill()
        except Exception as kill_err:
            print("shut down failed: " + str(kill_err))


def _result_to_score(result):
    """Convert one ``single_compute_score`` result into a float score."""
    if isinstance(result, Exception) or result is None:
        # Handle failed or timed-out tasks
        return 0.0
    value = result[0]  # single_compute_score wraps the evaluation result in a one-element list
    if isinstance(value, (int, float, bool)):
        return float(value)
    if isinstance(value, dict):
        # NOTE(review): assumes dict results carry the score under "score"
        # (verl's convention) -- confirm for custom compute_score functions.
        return float(value["score"])
    # Otherwise a sequence whose first element is the score.
    return float(value[0])


async def parallel_compute_score_async(
    evaluation_func, completions, references, tasks, extra_info=None, num_processes=64, timeout=300.0
):
    """Score all completions in parallel on a fresh process pool.

    Each row is scored via ``single_compute_score`` as
    ``evaluation_func(task, completion, reference, extra_info)`` with a per-row
    ``timeout``. Rows that time out or raise score 0.0.

    Args:
        evaluation_func: Picklable scoring callable (it runs in worker processes).
        completions: Decoded responses, one per row.
        references: Ground-truth values aligned with ``completions``.
        tasks: Data-source identifiers aligned with ``completions``.
        extra_info: Optional per-row extra info; ``None`` means ``None`` for every row.
        num_processes: Maximum number of worker processes.
        timeout: Per-row timeout in seconds (300 s, the previous hard-coded value).

    Returns:
        A list of float scores, one per scored row.

    Raises:
        Whatever propagates out of ``asyncio.gather`` (e.g. cancellation); the
        worker processes are killed before it propagates.
    """
    if extra_info is None:
        extra_info = [None] * len(tasks)
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        # Create one coroutine per row.
        row_coroutines = [
            single_compute_score(
                evaluation_func, completion, reference, task, task_extra_info, executor, timeout=timeout
            )
            for completion, reference, task, task_extra_info in zip(completions, references, tasks, extra_info)
        ]
        try:
            results = await asyncio.gather(*row_coroutines, return_exceptions=False)
        finally:
            # A timed-out row is abandoned by asyncio, but a job already running in
            # a worker cannot be cancelled. An anomalous program (e.g. an infinite
            # loop) keeps its worker busy, and leaving this ``with`` block calls
            # ``executor.shutdown(wait=True)``, which would block until that job
            # ends. Kill all workers first, on success and on failure alike.
            _kill_pool_workers(executor)
    return [_result_to_score(result) for result in results]
class PrimeRewardManager:
    """
    The Reward Manager used in https://github.com/PRIME-RL/PRIME

    ``verify`` scores every decoded response with ``compute_score`` in a
    process pool. ``__call__`` turns those scores into a token-level reward
    tensor by placing each row's score on its last valid response token.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        num_examine: int,
        compute_score: Optional[Callable] = None,
        reward_fn_key: str = "data_source",
    ) -> None:
        """
        Args:
            tokenizer: Tokenizer used to decode response token ids.
            num_examine: Maximum number of decoded responses printed per data
                source in each ``__call__``.
            compute_score: Scoring callable invoked as
                ``compute_score(data_source, response_str, ground_truth, extra_info)``;
                defaults to ``_default_compute_score``.
            reward_fn_key: Key in ``data.non_tensor_batch`` that holds each row's
                data source.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # max decoded responses printed per data source per call
        self.compute_score = compute_score or _default_compute_score
        self.reward_fn_key = reward_fn_key

    def verify(self, data):
        """
        verify the batch and save as ``acc`` tensor

        Decodes ``data.batch["responses"]`` and scores every row with
        ``self.compute_score`` via ``parallel_compute_score_async`` (64 worker
        processes). The scores are stored in ``data.batch["acc"]`` as a float32
        tensor on the device of ``data.batch["prompts"]`` and also returned as a
        list of floats. If batched scoring raises, every row scores 0.0.
        """
        # batched scoring
        prompt_ids = data.batch["prompts"]
        response_ids = data.batch["responses"]
        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        ground_truth = [data_item.non_tensor_batch["reward_model"]["ground_truth"] for data_item in data]
        data_sources = data.non_tensor_batch[self.reward_fn_key]
        extra_info = data.non_tensor_batch.get("extra_info", None)

        assert len(sequences_str) == len(ground_truth) == len(data_sources)
        try:
            scores = asyncio.run(
                parallel_compute_score_async(
                    self.compute_score,
                    sequences_str,
                    ground_truth,
                    data_sources,
                    extra_info=extra_info,
                    num_processes=64,
                )
            )
        except asyncio.TimeoutError:
            print("Global timeout in reward computing! Setting all as 0.")
            scores = [0.0 for _ in range(len(sequences_str))]
        except Exception as e:
            print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}")
            scores = [0.0 for _ in range(len(sequences_str))]
        data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
        return scores

    def __call__(self, data: DataProto, return_dict: bool = False):
        """Build the token-level reward tensor for ``data``.

        If ``data.batch`` already holds ``"rm_scores"``, they are returned as-is.
        Otherwise the reward tensor has the shape of ``data.batch["responses"]``.
        It is all zeros except each row's ``verify`` score, which sits at that
        row's last valid response position.

        Args:
            data: Batch with ``prompts``, ``responses`` and ``attention_mask`` tensors.
            return_dict: When True, return ``{"reward_tensor": tensor}`` instead
                of the bare tensor (also for the ``rm_scores`` shortcut).

        Returns:
            The reward tensor, or a dict wrapping it when ``return_dict`` is True.
        """
        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
        if "rm_scores" in data.batch.keys():
            rm_scores = data.batch["rm_scores"]
            return {"reward_tensor": rm_scores} if return_dict else rm_scores

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        already_print_data_sources = {}

        prompt_length = data.batch["prompts"].shape[-1]
        response_ids = data.batch["responses"]
        # Columns after the prompt belong to the response; summing the mask there
        # counts each row's response tokens (assumes a 0/1 attention mask).
        valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        # Same key as verify(), so a custom reward_fn_key is honoured here too.
        data_sources = data.non_tensor_batch[self.reward_fn_key]

        scores = self.verify(data)

        for i, data_source in enumerate(data_sources):
            # NOTE(review): a row with zero valid response tokens gets index -1,
            # i.e. the last column -- confirm this is acceptable for empty responses.
            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                # Print only this row's response, not the whole batch.
                print(sequences_str[i])

        if return_dict:
            return {"reward_tensor": reward_tensor}
        else:
            return reward_tensor
|