import hashlib import inspect from copy import copy from typing import Any, Dict, List, Optional import json import numpy as np from swift.llm import InferRequest, RequestConfig from swift.utils import get_logger logger = get_logger() def get_messages_md5(row: Dict[str, Any]): row = copy(row) row.pop('choices', None) serialized = json.dumps(row, sort_keys=True) return hashlib.md5(serialized.encode('utf-8')).hexdigest() def get_reward(model: Any, infer_requests: List[InferRequest], request_config: RequestConfig = None, ground_truths: List[str] = None, threshold: Optional[float] = None): """Get reward from an RM model. Args: model: The model instance or an RM evaluator infer_requests: Infer requests sent to the model request_config: Infer config ground_truths: The ground truth list threshold: An optional threshold to generate the mask Returns: Tuple Index 0: The min-max normalized scores matched the infer_requests Index 1: The mask filtered by the threshold """ from swift.llm import InferEngine infer_func = model.infer if isinstance(model, InferEngine) else model.__call__ parameters = inspect.signature(infer_func).parameters gt_param = {} if 'ground_truths' in parameters: gt_param = {'ground_truths': ground_truths} if isinstance(infer_requests[0], dict): infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests] rewards = infer_func(infer_requests, request_config=request_config, **gt_param) from swift.llm.infer.protocol import ChatCompletionResponse if isinstance(rewards[0], ChatCompletionResponse): print('reward:', rewards[0].choices[0].message.content) if isinstance(rewards[0].choices[0].message.content, str): rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards] elif isinstance(rewards[0].choices[0].message.content, list): rewards = [float(min(r.choices[0].message.content)) for r in rewards] else: rewards = [float(r.choices[0].message.content) for r in rewards] arr = [] for reward in rewards: if isinstance(reward, (list, tuple)): arr.append(min(reward)) else: arr.append(float(reward)) _mask = np.array([True] * len(arr)) if threshold is not None: # > not >=, orm caller passes 0, which will cause error _mask = np.array([a > threshold for a in arr]) def normalize(arr): min_val = np.min(arr) max_val = np.max(arr) if min_val == max_val: if min_val == 0: constant_value = 0.0 else: constant_value = min(1.0, min_val) return np.full_like(arr, fill_value=constant_value, dtype=np.float64) normalized = (arr - min_val) / (max_val - min_val + 1e-5) return normalized return normalize(arr), _mask def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs): if isinstance(infer_engines, list): assert len(infer_engines) >= len(request_configs) >= len(infer_requests) from concurrent.futures import ThreadPoolExecutor, as_completed n = len(infer_requests) with ThreadPoolExecutor(max_workers=n) as executor: futures = { executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs): i for i in range(n) } responses = [] for future in as_completed(futures): task_id = futures[future] try: responses += future.result() except Exception as e: logger.info(f'Perform infer task: {task_id} get an error: {e}') return responses elif isinstance(infer_requests, list): responses = [] if isinstance(request_configs, list): assert len(infer_requests) <= len(request_configs) for i in range(len(infer_requests)): responses += infer_engines.infer( [infer_requests[i]], request_configs[i], **infer_kwargs, ) elif isinstance(request_configs, RequestConfig): for infer_request in infer_requests: responses += infer_engines.infer( [infer_request], request_configs, **infer_kwargs, ) return responses return infer_engines.infer( [infer_requests], request_configs, **infer_kwargs, ) def collect_from_mct(monte_carlo_tree, collect_filter_threshold): from transformers.utils import strtobool if isinstance(monte_carlo_tree, str): monte_carlo_tree = json.loads(monte_carlo_tree) def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]): _prefer_pairs, _correct_answers, _incorrect_answers = [], [], [] _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']] _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']] if len(collect_curr_node['children']) > 0: for child in collect_curr_node['children']: p, c, i = _collect(child, _outcome_rewards, _process_rewards) _prefer_pairs += p _correct_answers += c _incorrect_answers += i sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward']) if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold: # TODO: filter with visit count prefer_pair = { 'path': 'ки\n'.join(collect_curr_node['path']), 'good': sorted_children[-1]['path'][-1], 'good_score': sorted_children[-1]['outcome_reward'], 'bad': sorted_children[0]['path'][-1], 'bad_score': sorted_children[0]['outcome_reward'], } _prefer_pairs.append(prefer_pair) if strtobool(collect_curr_node['terminated']): _answer = { 'answer': 'ки\n'.join(collect_curr_node['path']), 'mean_outcome_reward': np.mean(_outcome_rewards), 'min_outcome_reward': np.min(_outcome_rewards), 'mean_process_reward': np.mean(_process_rewards), 'min_process_reward': np.min(_process_rewards), } if strtobool(collect_curr_node['correct']): _correct_answers.append(_answer) else: _incorrect_answers.append(_answer) return _prefer_pairs, _correct_answers, _incorrect_answers _root = monte_carlo_tree prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], []) return prefer_pairs, correct_answers, incorrect_answers