|
|
import hashlib |
|
|
import inspect |
|
|
from copy import copy |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import json |
|
|
import numpy as np |
|
|
|
|
|
from swift.llm import InferRequest, RequestConfig |
|
|
from swift.utils import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
def get_messages_md5(row: Dict[str, Any]):
    """Return a stable MD5 fingerprint of ``row``, ignoring its 'choices' key.

    The row is serialized with ``sort_keys=True`` so that two rows with the
    same content always hash identically regardless of key insertion order.
    The input mapping is never mutated.

    Args:
        row: A sample dict; the 'choices' entry (if present) is excluded
            from the hash so that different completions of the same prompt
            collapse to one fingerprint.

    Returns:
        The hex digest string of the MD5 hash.
    """
    hashable = {key: value for key, value in row.items() if key != 'choices'}
    serialized = json.dumps(hashable, sort_keys=True)
    return hashlib.md5(serialized.encode('utf-8')).hexdigest()
|
|
|
|
|
|
|
|
def get_reward(model: Any,
               infer_requests: List[InferRequest],
               request_config: Optional[RequestConfig] = None,
               ground_truths: Optional[List[str]] = None,
               threshold: Optional[float] = None):
    """Get reward from an RM model.

    Args:
        model: The model instance (an ``InferEngine``) or a callable RM evaluator
        infer_requests: Infer requests sent to the model; raw dicts with a
            'messages' key are wrapped into ``InferRequest`` automatically
        request_config: Infer config
        ground_truths: The ground truth list; only forwarded when the
            evaluator's signature declares a ``ground_truths`` parameter
        threshold: An optional threshold to generate the mask

    Returns:
        Tuple
        Index 0: The min-max normalized scores matched the infer_requests
        Index 1: The mask of scores strictly above ``threshold``
            (all ``True`` when ``threshold`` is None)
    """
    from swift.llm import InferEngine
    infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
    # Only pass ground_truths to evaluators whose signature accepts it.
    parameters = inspect.signature(infer_func).parameters
    gt_param = {}
    if 'ground_truths' in parameters:
        gt_param = {'ground_truths': ground_truths}
    if isinstance(infer_requests[0], dict):
        infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
    rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
    from swift.llm.infer.protocol import ChatCompletionResponse
    if isinstance(rewards[0], ChatCompletionResponse):
        # Route debug output through the module logger instead of a stray print().
        logger.info(f'reward: {rewards[0].choices[0].message.content}')
        first_content = rewards[0].choices[0].message.content
        if isinstance(first_content, str):
            # e.g. '[0.8]' -> 0.8
            rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards]
        elif isinstance(first_content, list):
            # Per-step scores: take the most pessimistic step.
            rewards = [float(min(r.choices[0].message.content)) for r in rewards]
        else:
            rewards = [float(r.choices[0].message.content) for r in rewards]
    # Flatten each reward to one scalar: list/tuple -> its minimum, else float.
    arr = []
    for reward in rewards:
        if isinstance(reward, (list, tuple)):
            arr.append(min(reward))
        else:
            arr.append(float(reward))

    _mask = np.ones(len(arr), dtype=bool)
    if threshold is not None:
        _mask = np.array([a > threshold for a in arr])

    def normalize(values):
        """Min-max normalize ``values``; constant inputs map to a flat array."""
        # Make the array conversion explicit instead of relying on numpy
        # broadcasting against a Python list below.
        values = np.asarray(values, dtype=np.float64)
        min_val = np.min(values)
        max_val = np.max(values)
        if min_val == max_val:
            # All rewards identical: keep 0 at 0, otherwise cap at 1.0 so a
            # uniformly high reward still reads as "full score".
            if min_val == 0:
                constant_value = 0.0
            else:
                constant_value = min(1.0, min_val)
            return np.full_like(values, fill_value=constant_value, dtype=np.float64)
        # The epsilon keeps the divisor non-zero against float round-off.
        normalized = (values - min_val) / (max_val - min_val + 1e-5)
        return normalized

    return normalize(arr), _mask
|
|
|
|
|
|
|
|
def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
    """Run inference, fanning out over engines/requests as the argument shapes dictate.

    Three call shapes are supported:
      * ``infer_engines`` is a list: engine ``i`` serves ``infer_requests[i]`` with
        ``request_configs[i]``, all shards running concurrently in threads.
      * a single engine with a list of requests: each request is sent one by one,
        paired either with its own config (list) or one shared ``RequestConfig``.
      * a single engine with a single request: one direct ``infer`` call.

    Args:
        infer_engines: One engine or a list of engines exposing ``.infer``.
        infer_requests: One request, or a list aligned with the engines/configs.
        request_configs: One ``RequestConfig``, or a list aligned with the requests.
        **infer_kwargs: Extra keyword arguments forwarded to ``infer``.

    Returns:
        A flat list of responses. In the threaded branch the order follows task
        completion, not submission, and failed shards are logged and skipped.
    """
    if isinstance(infer_engines, list):
        assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
        from concurrent.futures import ThreadPoolExecutor, as_completed
        n = len(infer_requests)
        if n == 0:
            # ThreadPoolExecutor rejects max_workers=0; nothing to do anyway.
            return []
        with ThreadPoolExecutor(max_workers=n) as executor:
            futures = {
                executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs):
                i
                for i in range(n)
            }
            responses = []
            for future in as_completed(futures):
                task_id = futures[future]
                try:
                    responses += future.result()
                except Exception as e:
                    # Best-effort fan-out: a failed shard is dropped, but surface
                    # it at error level rather than burying it in info logs.
                    logger.error(f'Perform infer task: {task_id} get an error: {e}')
            return responses
    elif isinstance(infer_requests, list):
        responses = []
        if isinstance(request_configs, list):
            assert len(infer_requests) <= len(request_configs)
            for i, infer_request in enumerate(infer_requests):
                responses += infer_engines.infer(
                    [infer_request],
                    request_configs[i],
                    **infer_kwargs,
                )
        elif isinstance(request_configs, RequestConfig):
            # One shared config for every request.
            for infer_request in infer_requests:
                responses += infer_engines.infer(
                    [infer_request],
                    request_configs,
                    **infer_kwargs,
                )
        return responses
    return infer_engines.infer(
        [infer_requests],
        request_configs,
        **infer_kwargs,
    )
|
|
|
|
|
|
|
|
def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
    """Harvest training data from a Monte Carlo tree of reasoning paths.

    Walks the tree once and collects three kinds of samples:
      * preference pairs: at any node whose children's outcome rewards spread
        by more than ``collect_filter_threshold``, the best child step is
        'good' and the worst is 'bad';
      * correct answers and incorrect answers: every terminated leaf path,
        bucketed by its 'correct' flag, with mean/min outcome and process
        rewards accumulated along the root-to-leaf path.

    Args:
        monte_carlo_tree: The root node dict, or its JSON string form.
            Each node is expected to carry 'outcome_reward', 'process_reward',
            'children' (list), 'path' (list of step strings), 'terminated'
            and 'correct' — assumed schema from usage here; the truthy flags
            go through ``strtobool`` so they are presumably stored as strings
            like 'true'/'false' (TODO confirm against the producer).
        collect_filter_threshold: Minimum best-minus-worst child outcome
            reward required to emit a preference pair.

    Returns:
        Tuple of (prefer_pairs, correct_answers, incorrect_answers), each a
        list of dicts.
    """
    from transformers.utils import strtobool
    if isinstance(monte_carlo_tree, str):
        monte_carlo_tree = json.loads(monte_carlo_tree)

    def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]):
        # Depth-first recursion; returns the samples gathered in this subtree.
        _prefer_pairs, _correct_answers, _incorrect_answers = [], [], []
        # Copy-and-extend so siblings do not see each other's reward history.
        _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']]
        _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']]
        if len(collect_curr_node['children']) > 0:
            for child in collect_curr_node['children']:
                p, c, i = _collect(child, _outcome_rewards, _process_rewards)
                _prefer_pairs += p
                _correct_answers += c
                _incorrect_answers += i
            # Ascending by outcome reward: worst child first, best child last.
            sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward'])
            if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold:
                # 'ки' joins the steps of the shared prefix path — it appears to
                # be a step-separator token for process-reward training data.
                prefer_pair = {
                    'path': 'ки\n'.join(collect_curr_node['path']),
                    'good': sorted_children[-1]['path'][-1],
                    'good_score': sorted_children[-1]['outcome_reward'],
                    'bad': sorted_children[0]['path'][-1],
                    'bad_score': sorted_children[0]['outcome_reward'],
                }
                _prefer_pairs.append(prefer_pair)
        if strtobool(collect_curr_node['terminated']):
            # Terminated path: record the full answer with the reward
            # statistics accumulated from the root down to this node.
            _answer = {
                'answer': 'ки\n'.join(collect_curr_node['path']),
                'mean_outcome_reward': np.mean(_outcome_rewards),
                'min_outcome_reward': np.min(_outcome_rewards),
                'mean_process_reward': np.mean(_process_rewards),
                'min_process_reward': np.min(_process_rewards),
            }
            if strtobool(collect_curr_node['correct']):
                _correct_answers.append(_answer)
            else:
                _incorrect_answers.append(_answer)
        return _prefer_pairs, _correct_answers, _incorrect_answers

    _root = monte_carlo_tree
    prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], [])
    return prefer_pairs, correct_answers, incorrect_answers
|
|
|