import json
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy

import numpy as np

from swift.llm import InferRequest, SamplingArguments
from swift.llm.infer.protocol import UsageInfo
from swift.utils import get_logger
from .base import Sampler
from .utils import get_reward, perform_infer

logger = get_logger()

NXT_PROMPT = """Continue.
"""

next_message = {
    'role': 'user',
    'content': NXT_PROMPT,
}


class LanguageNode:
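    """A node in the search tree: one reasoning step plus the answer text accumulated along its path."""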

    def __init__(
        self,
        step: str = None,
        sep_token: str = None,
        parent: 'LanguageNode' = None,
    ):
        self.parent = parent

        if sep_token:
            self.sep_token = sep_token
        else:
            self.sep_token = parent.sep_token

        if parent:
            self.path = parent.path[:] + [step]
            self.answer = parent.answer + step + self.sep_token
            self.depth = parent.depth + 1
        else:
            self.path = []
            self.answer = ''
            self.depth = 0

        self.active_children = []
        self.children = []
        self.visit_count = 0
        self.process_reward = 0.0
        self.outcome_reward = 0.0
        self.terminated = False
        self.correct = False

    def is_leaf(self):
        return len(self.children) == 0

    def is_root(self):
        return self.parent is None

    def visit(self):
        self.visit_count += 1

    def init_and_update_value(self, value):
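        # Update the running average of outcome rewards observed for this node.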
        self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1)

    def add_child(self, child: 'LanguageNode'):
        self.children.append(child)
        if not child.terminated:
            self.active_children.append(child)

    def collect(self):
        result = {
            'path': self.path,
            'depth': self.depth,
            'visit_count': self.visit_count,
            'process_reward': self.process_reward,
            'outcome_reward': self.outcome_reward,
            'terminated': str(self.terminated),
            'correct': str(self.correct),
            'children': [child.collect() for child in self.children],
        }
        return result

    def __lt__(self, other):
        return self.outcome_reward < other.outcome_reward


class MctsSampler(Sampler):
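    """Sampler that explores reasoning paths with Monte Carlo Tree Search (MCTS)."""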

    def __init__(self, input_args: SamplingArguments):
        super().__init__(input_args)
        self.usage_info = UsageInfo(0, 0, 0)

    def _prepare_model_tokenizer(self):
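        # With the 'client' engine, one InferClient is created per returned sequence;
        # otherwise a single local engine (pt/vllm/lmdeploy) is instantiated.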
        args = self.args
        self.infer_kwargs = {}
        if args.sampler_engine == 'client':
            from swift.llm import InferClient
            api_key = args.api_key
            base_url = args.base_url
            self.infer_engine = [
                InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)
            ]
            self.infer_kwargs['model'] = args.model
        else:
            _Engine = self.get_infer_engine()
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)

    def get_infer_engine(self):
        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            _Engine = None
        else:
            raise ValueError(f'Unsupported sampler_engine: {self.args.sampler_engine}')
        return _Engine

    def _prepare_template(self) -> None:
        self._prepare_request_configs()

    def _prepare_request_configs(self):
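        # One sampling config per returned sequence for expansion (n=1, distinct seeds),
        # plus a greedy, length-capped config used for rollouts.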
        _args = self.args
        request_config = _args.get_request_config()
        request_config.stop = _args.stop_words
        request_config.seed = _args.seed
        self.expand_request_configs = []
        self.rollout_request_configs = []
        for i in range(_args.num_return_sequences):
            expand_request_config = deepcopy(request_config)
            expand_request_config.n = 1
            expand_request_config.num_beams = expand_request_config.n
            expand_request_config.seed += i
            self.expand_request_configs.append(expand_request_config)
            rollout_request_config = deepcopy(request_config)
            rollout_request_config.max_tokens = 500
            rollout_request_config.temperature = 0.0
            rollout_request_config.n = 1
            self.rollout_request_configs.append(rollout_request_config)

    def update_usage_info(self, response):
        for key, value in self.usage_info.__dict__.items():
            # Guard against usage fields that are missing or None on the response.
            update_value = (getattr(response.usage, key, None) or 0) + value
            setattr(self.usage_info, key, update_value)

    def search_single(self, query, ground_truth):
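        """Run MCTS for one query: repeatedly select a leaf, expand it, optionally roll
        out its children, and back-propagate rewards, until a stop condition is met
        (too easy/too hard, root terminated, enough terminated nodes, or
        max_iterations). Returns the result as a JSON string."""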

        def _uct(uct_curr_node: LanguageNode):
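            # UCT: exploitation term blending process and outcome rewards, plus an
            # exploration bonus that shrinks as the node accumulates visits.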
            alpha = _args.process_reward_rate
            value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
            if uct_curr_node.is_root():
                return value

            exploitation_score = value
            exploration_score = (
                _args.exploration_rate
                * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))

            return exploration_score + exploitation_score

        def _select(select_curr_node: LanguageNode):
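            # Descend from the given node, always following the highest-UCT active child,
            # until a leaf is reached.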
            while not select_curr_node.is_leaf():
                select_curr_node = max(select_curr_node.active_children, key=_uct)
            return select_curr_node

        def _expand(expand_curr_node: LanguageNode):
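            # Sample up to num_return_sequences continuations, deduplicate them, attach
            # them as children, and score them with the outcome (and optional process)
            # reward models.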
            n = _args.num_return_sequences - len(expand_curr_node.children)
            if expand_curr_node.is_root():
                infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
            else:
                history_message = {
                    'role': 'assistant',
                    'content': expand_curr_node.answer,
                }
                infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
                infer_requests = [infer_request for _ in range(n)]

            # Retry inference a few times; give up after repeated empty batches.
            expand_iter_index = 0
            while True:
                responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
                                          **self.infer_kwargs)
                if len(responses) > 0:
                    break
                if expand_iter_index == 5:
                    raise ValueError('Expand returned no responses after multiple retries')
                expand_iter_index += 1

            # Deduplicate the generated steps and build ORM requests for the new children.
            orm_infer_requests = []
            unique_output = set()
            for response in responses:
                self.update_usage_info(response)
                output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
                if output in unique_output:
                    continue
                unique_output.add(output)
                orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
                child = LanguageNode(step=output, parent=expand_curr_node)
                if self.orm_model.check_terminate(child.answer)[0]:
                    child.terminated = True
                expand_curr_node.add_child(child)

            # Score children with the outcome reward model; terminated children keep their
            # score and are collected as finished paths.
            orm_score, _orm_mask = get_reward(
                self.orm_model,
                orm_infer_requests,
                ground_truths=[ground_truth] * len(orm_infer_requests),
                threshold=0.0)
            for child, score in zip(expand_curr_node.children, orm_score):
                if child.terminated:
                    child.init_and_update_value(score)
                    child.correct = score > 0.9
                    terminated_nodes.append(child)

            # Optionally score every child with the process reward model.
            if self.prm_model:
                prm_infer_requests = []
                for child in expand_curr_node.children:
                    prm_message = {'role': 'assistant', 'content': child.answer}
                    prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    prm_infer_requests,
                    ground_truths=[ground_truth] * len(prm_infer_requests),
                    threshold=0.0)
                for child, score in zip(expand_curr_node.children, prm_score):
                    child.process_reward = score

        def _rollout(rollout_curr_node: LanguageNode):
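            # Greedily roll out each active child for at most rollout_depth extra steps,
            # scoring finished answers with the ORM and recording them as correct or
            # incorrect rollout answers.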
            rollout_depth = 0
            rollout_nodes = {}
            for i in range(len(rollout_curr_node.active_children)):
                rollout_nodes[i] = {
                    'node': rollout_curr_node.active_children[i],
                    'history_messages': {
                        'role': 'assistant',
                        'content': rollout_curr_node.active_children[i].answer,
                    },
                }
            active_rollout_nodes = list(rollout_nodes.keys())
            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
                infer_requests = [
                    InferRequest(system_message
                                 + [prompt_message, rollout_nodes[index]['history_messages'], next_message])
                    for index in active_rollout_nodes
                ]

                # Retry inference a few times; give up after repeated empty batches.
                rollout_iter_index = 0
                while True:
                    responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
                                              **self.infer_kwargs)
                    if len(responses) > 0:
                        break
                    if rollout_iter_index == 5:
                        raise ValueError('Rollout returned no responses after multiple retries')
                    rollout_iter_index += 1

                # Append each new step to its rollout history and score the extended answers.
                orm_infer_requests = []
                end_paths = []
                for index, response in zip(active_rollout_nodes, responses):
                    self.update_usage_info(response)
                    output = response.choices[0].message.content.rstrip(sep_token
                                                                        + '\n').split(sep_token)[0] + sep_token + '\n'
                    rollout_nodes[index]['history_messages']['content'] += output
                    end_paths.append(rollout_nodes[index]['history_messages']['content'])
                    orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))

                orm_score, _orm_mask = get_reward(
                    self.orm_model,
                    orm_infer_requests,
                    ground_truths=[ground_truth] * len(orm_infer_requests),
                    threshold=0.0)
                terminated_state = self.orm_model.check_terminate(end_paths)
                for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
                    if terminated:
                        rollout_curr_node.active_children[index].init_and_update_value(score)
                        if score > 0.9:
                            rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
                        else:
                            rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
                        rollout_nodes.pop(index)
                active_rollout_nodes = list(rollout_nodes.keys())
                rollout_depth += 1

        def _back_propagate(back_curr_node: LanguageNode):
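            # Propagate the best child value from the expanded node up to the root,
            # updating visit counts and deactivating fully terminated branches.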
            while back_curr_node:
                if back_curr_node == curr_node:
                    best_child_value = max([child.outcome_reward for child in back_curr_node.children])
                    back_curr_node.init_and_update_value(best_child_value)
                    last_child_value = back_curr_node.outcome_reward
                else:
                    back_curr_node.init_and_update_value(last_child_value)
                    last_child_value = back_curr_node.outcome_reward
                back_curr_node.visit()
                if len(back_curr_node.active_children) == 0:
                    back_curr_node.terminated = True
                    if not back_curr_node.is_root():
                        back_curr_node.parent.active_children.remove(back_curr_node)
                back_curr_node = back_curr_node.parent

        _args = self.args
        system_message = [] + _args.system_message
        sep_token = _args.stop_words[0] + '\n'
        _root = LanguageNode(sep_token=sep_token)
        prompt_message = {
            'role': 'user',
            'content': query,
        }

        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
        iter_count = 0
        stop_reason = None
        while True:
            logger.info(f'iter_count: {iter_count}' + '.' * 10)
            s_time = time.time()
            curr_node = _select(_root)
            logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _expand(curr_node)
            logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
            if curr_node.depth > _args.rollout_start_depth:
                s_time = time.time()
                _rollout(curr_node)
                logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _back_propagate(curr_node)
            logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
            if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
                    stop_reason = 'too easy'
                    break
                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
                    stop_reason = 'too hard'
                    break
            if _root.terminated:
                stop_reason = 'root terminated'
                break
            if len(terminated_nodes) >= _args.num_return_sequences:
                stop_reason = 'enough nodes'
                break
            if iter_count >= _args.max_iterations:
                stop_reason = 'max_iterations'
                break
            iter_count += 1
        logger.info(f'stop_reason: {stop_reason}')

        monte_carlo_tree = _root.collect()
        result = {
            'query': query,
            'ground_truth': ground_truth,
            'rollout_correct_answers': rollout_correct_answers,
            'rollout_incorrect_answers': rollout_incorrect_answers,
            'monte_carlo_tree': monte_carlo_tree,
        }
        result_json = json.dumps(result, ensure_ascii=False)
        logger.info(result_json)
        return result_json

    def do_sample(self, data):
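        # Each item's first conversation provides the query (message 0) and the ground
        # truth (message 1); one JSON line is returned per successfully searched item.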
        if not isinstance(data, list):
            data = [data]
        generated = []
        for item in data:
            logger.info(f'time: {time.ctime(time.time())}')
            try:
                messages = item['messages'][0]
                query = messages[0]['content']
                ground_truth = messages[1]['content']
                generated.append(self.search_single(query, ground_truth) + '\n')
            except Exception as e:
                logger.error(f'Error: {e}')
                logger.error(f'Traceback: {traceback.format_exc()}')
        return generated