|
|
import json
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import Optional

import numpy as np

from swift.llm import InferRequest, SamplingArguments
from swift.llm.infer.protocol import UsageInfo
from swift.utils import get_logger

from .base import Sampler
from .utils import get_reward, perform_infer
|
|
|
|
|
# Module-level logger shared by the sampler implementation below.
logger = get_logger()

# Prompt used to ask the model to keep generating from a partial answer.
NXT_PROMPT = """Continue.
"""

# Reusable chat turn carrying NXT_PROMPT; appended as the trailing user
# message when expanding or rolling out from a non-root tree node.
next_message = {
    'role': 'user',
    'content': NXT_PROMPT,
}
|
|
|
|
|
|
|
|
class LanguageNode:
    """One node of the Monte-Carlo search tree over reasoning steps.

    Each node adds a single generated ``step``; ``answer`` accumulates the
    full text from the root down to (and including) this node, with every
    step followed by ``sep_token``.
    """

    def __init__(
        self,
        step: Optional[str] = None,
        sep_token: Optional[str] = None,
        parent: Optional['LanguageNode'] = None,
    ):
        """Create a root node (pass ``sep_token``) or a child (pass ``parent``).

        Args:
            step: Reasoning-step text this node contributes; ``None`` for the root.
            sep_token: Separator appended after each step; when omitted it is
                inherited from ``parent`` (so a root must provide one).
            parent: Parent node; ``None`` makes this node the root.
        """
        self.parent = parent

        # Inherit the separator from the parent unless explicitly provided.
        if sep_token:
            self.sep_token = sep_token
        else:
            self.sep_token = parent.sep_token

        if parent:
            self.path = parent.path[:] + [step]
            self.answer = parent.answer + step + self.sep_token
            self.depth = parent.depth + 1
        else:
            self.path = []
            self.answer = ''
            self.depth = 0

        # Children still eligible for selection (i.e. not terminated).
        self.active_children = []
        self.children = []
        self.visit_count = 0
        self.process_reward = 0.0
        self.outcome_reward = 0.0
        self.terminated = False
        self.correct = False

    def is_leaf(self) -> bool:
        """Return True when this node has no children."""
        return len(self.children) == 0

    def is_root(self) -> bool:
        """Return True when this node has no parent."""
        return self.parent is None

    def visit(self) -> None:
        """Record one more visit to this node."""
        self.visit_count += 1

    def init_and_update_value(self, value) -> None:
        """Fold ``value`` into the running mean of the outcome reward.

        ``visit_count`` itself is advanced separately via :meth:`visit`, so
        the mean is taken over ``visit_count + 1`` observations.
        """
        self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1)

    def add_child(self, child: 'LanguageNode') -> None:
        """Attach ``child``; non-terminated children remain selectable."""
        self.children.append(child)
        if not child.terminated:
            self.active_children.append(child)

    def collect(self) -> dict:
        """Recursively serialize the subtree rooted here into plain dicts.

        Booleans are rendered with ``str`` so the result is JSON-friendly in
        the same shape the caller logs.
        """
        result = {
            'path': self.path,
            'depth': self.depth,
            'visit_count': self.visit_count,
            'process_reward': self.process_reward,
            'outcome_reward': self.outcome_reward,
            'terminated': str(self.terminated),
            'correct': str(self.correct),
            'children': [child.collect() for child in self.children],
        }
        return result

    def __lt__(self, other):
        # Order nodes by their mean outcome reward.
        return self.outcome_reward < other.outcome_reward
|
|
|
|
|
|
|
|
class MctsSampler(Sampler):
    """Sampler that searches for answers with Monte-Carlo Tree Search.

    For each query it grows a tree of reasoning steps (``LanguageNode``),
    iterating select -> expand -> (rollout) -> back-propagate.  Candidate
    answers are scored by the outcome reward model (``self.orm_model``) and,
    when configured, a process reward model (``self.prm_model``).
    """

    def __init__(self, input_args: SamplingArguments):
        super().__init__(input_args)
        # Aggregated token usage across all inference calls.
        self.usage_info = UsageInfo(0, 0, 0)

    def _prepare_model_tokenizer(self):
        """Instantiate the inference engine(s) selected by ``sampler_engine``."""
        args = self.args
        self.infer_kwargs = {}
        if args.sampler_engine == 'client':
            from swift.llm import InferClient
            api_key = args.api_key
            base_url = args.base_url
            # One client per returned sequence so requests can be fanned out.
            self.infer_engine = [
                InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)
            ]
            self.infer_kwargs['model'] = args.model
        else:
            _Engine = self.get_infer_engine()
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)

    def get_infer_engine(self):
        """Map ``sampler_engine`` to an engine class; 'no' yields None."""
        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            _Engine = None
        else:
            raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
        return _Engine

    def _prepare_template(self) -> None:
        # No template work is needed for MCTS; only request configs are built.
        self._prepare_request_configs()

    def _prepare_request_configs(self):
        """Build per-sequence request configs for the expand and rollout phases.

        Expand configs differ only by seed (seed + i) so the i-th sequence is
        reproducible but distinct; rollout configs are greedy (temperature 0)
        and capped at 500 tokens.
        """
        _args = self.args
        request_config = _args.get_request_config()
        request_config.stop = _args.stop_words
        request_config.seed = _args.seed
        self.expand_request_configs = []
        self.rollout_request_configs = []
        for i in range(_args.num_return_sequences):
            expand_request_config = deepcopy(request_config)
            expand_request_config.n = 1
            expand_request_config.num_beams = expand_request_config.n
            expand_request_config.seed += i
            self.expand_request_configs.append(expand_request_config)
            rollout_request_config = deepcopy(request_config)
            rollout_request_config.max_tokens = 500
            rollout_request_config.temperature = 0.0
            rollout_request_config.n = 1
            self.rollout_request_configs.append(rollout_request_config)

    def update_usage_info(self, response):
        """Accumulate the response's token usage into ``self.usage_info``.

        NOTE(review): ``getattr(response.usage, key, None) + value`` raises
        TypeError when a usage field is absent (None + int) — this assumes
        every response carries a fully-populated usage object; confirm.
        """
        for key, value in self.usage_info.__dict__.items():
            update_value = getattr(response.usage, key, None) + value
            setattr(self.usage_info, key, update_value)

    def search_single(self, query, ground_truth):
        """Run one full MCTS search for ``query`` and return a JSON string.

        The JSON payload contains the query, ground truth, the rollout
        answers judged correct/incorrect by the ORM, and the serialized tree.
        The nested helpers below share the enclosing state (``_args``,
        ``sep_token``, ``prompt_message``, reward lists, ``curr_node``).
        """

        def _uct(uct_curr_node: LanguageNode):
            # UCT score: blend of process/outcome reward (exploitation) plus
            # an exploration bonus that shrinks with this node's visit count.
            alpha = _args.process_reward_rate
            value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
            if uct_curr_node.is_root():
                return value

            exploitation_score = value
            exploration_score = (
                _args.exploration_rate
                * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))

            return exploration_score + exploitation_score

        def _select(select_curr_node: LanguageNode):
            # Walk down the tree, always taking the active child with the
            # highest UCT score, until reaching a leaf.
            while not select_curr_node.is_leaf():
                select_curr_node = max(select_curr_node.active_children, key=lambda x: _uct(x))
            return select_curr_node

        def _expand(expand_curr_node: LanguageNode):
            # Generate up to num_return_sequences children for this node,
            # deduplicate outputs, mark terminal ones, and score them.
            n = _args.num_return_sequences - len(expand_curr_node.children)
            if expand_curr_node.is_root():
                infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
            else:
                history_message = {
                    'role': 'assistant',
                    'content': expand_curr_node.answer,
                }
                infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
                infer_requests = [infer_request for _ in range(n)]

            # Retry up to 5 extra times if the engine returns no responses.
            expand_iter_index = 0
            while True:
                responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
                                          **self.infer_kwargs)
                if len(responses) > 0:
                    break
                if expand_iter_index == 5:
                    raise ValueError('Expand should not return any response')
                expand_iter_index += 1

            orm_infer_requests = []
            unique_output = set()
            for response in responses:
                self.update_usage_info(response)
                # NOTE(review): str.rstrip strips a *character set*, not a
                # suffix, so this may trim more than the literal sep_token —
                # confirm this matches the intended step splitting.
                output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
                if output in unique_output:
                    continue
                unique_output.add(output)
                orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
                child = LanguageNode(step=output, parent=expand_curr_node)
                if self.orm_model.check_terminate(child.answer)[0]:
                    child.terminated = True
                expand_curr_node.add_child(child)

            # Outcome-score each unique new answer against the ground truth.
            orm_score, _orm_mask = get_reward(
                self.orm_model,
                orm_infer_requests,
                ground_truths=[ground_truth] * len(orm_infer_requests),
                threshold=0.0)

            # Terminal children get their outcome value immediately; score
            # > 0.9 is treated as a correct final answer.
            for child, score in zip(expand_curr_node.children, orm_score):
                if child.terminated:
                    child.init_and_update_value(score)
                    child.correct = score > 0.9
                    terminated_nodes.append(child)

            # Optionally attach a process reward to every child.
            if self.prm_model:
                prm_infer_requests = []
                for child in expand_curr_node.children:
                    prm_message = {'role': 'assistant', 'content': child.answer}
                    prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    prm_infer_requests,
                    ground_truths=[ground_truth] * len(prm_infer_requests),
                    threshold=0.0)
                for child, score in zip(expand_curr_node.children, prm_score):
                    child.process_reward = score

        def _rollout(rollout_curr_node: LanguageNode):
            # Greedily continue each active child up to rollout_depth steps,
            # recording finished answers as correct/incorrect by ORM score.
            rollout_depth = 0
            rollout_nodes = {}
            for i in range(len(rollout_curr_node.active_children)):
                rollout_nodes[i] = {
                    'node': rollout_curr_node.active_children[i],
                    'history_messages': {
                        'role': 'assistant',
                        'content': rollout_curr_node.active_children[i].answer,
                    },
                }
            active_rollout_nodes = list(rollout_nodes.keys())
            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
                infer_requests = [
                    InferRequest(system_message
                                 + [prompt_message, rollout_nodes[index]['history_messages'], next_message])
                    for index in active_rollout_nodes
                ]

                # Retry up to 5 extra times if the engine returns nothing.
                rollout_iter_index = 0
                while True:
                    responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
                                              **self.infer_kwargs)
                    if len(responses) > 0:
                        break
                    if rollout_iter_index == 5:
                        raise ValueError('Rollout should not return any response')
                    rollout_iter_index += 1

                orm_infer_requests = []
                end_paths = []
                for index, response in zip(active_rollout_nodes, responses):
                    self.update_usage_info(response)
                    # Append exactly one more step (re-terminated with
                    # sep_token) to this child's running transcript.
                    output = response.choices[0].message.content.rstrip(sep_token
                                                                        + '\n').split(sep_token)[0] + sep_token + '\n'
                    rollout_nodes[index]['history_messages']['content'] += output
                    end_paths.append(rollout_nodes[index]['history_messages']['content'])
                    orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))

                # NOTE(review): ground_truths is sized by len(infer_requests),
                # but the requests scored are orm_infer_requests — these can
                # differ if fewer responses than requests came back; confirm.
                orm_score, _orm_mask = get_reward(
                    self.orm_model,
                    orm_infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=0.0)

                # Finished paths propagate their score to the originating
                # child and are removed from the active rollout set.
                terminated_state = self.orm_model.check_terminate(end_paths)
                for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
                    if terminated:
                        rollout_curr_node.active_children[index].init_and_update_value(score)
                        if score > 0.9:
                            rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
                        else:
                            rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
                        rollout_nodes.pop(index)
                active_rollout_nodes = list(rollout_nodes.keys())
                rollout_depth += 1

        def _back_propagate(back_curr_node: LanguageNode):
            # Push the best child value of the expanded node up to the root,
            # updating running means and visit counts along the way.  A node
            # whose active children are exhausted becomes terminated and is
            # pruned from its parent's active set.
            while back_curr_node:
                if back_curr_node == curr_node:
                    # Seed the propagation with the best child outcome.
                    best_child_value = max([child.outcome_reward for child in back_curr_node.children])
                    back_curr_node.init_and_update_value(best_child_value)
                    last_child_value = back_curr_node.outcome_reward
                else:
                    back_curr_node.init_and_update_value(last_child_value)
                    last_child_value = back_curr_node.outcome_reward
                back_curr_node.visit()
                if len(back_curr_node.active_children) == 0:
                    back_curr_node.terminated = True
                    if not back_curr_node.is_root():
                        back_curr_node.parent.active_children.remove(back_curr_node)
                back_curr_node = back_curr_node.parent

        _args = self.args
        system_message = [] + _args.system_message
        # The first stop word (plus newline) delimits reasoning steps.
        sep_token = _args.stop_words[0] + '\n'
        _root = LanguageNode(sep_token=sep_token)
        prompt_message = {
            'role': 'user',
            'content': query,
        }

        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
        iter_count = 0
        stop_reason = None
        # Main MCTS loop: select -> expand -> (rollout past start depth) ->
        # back-propagate, until one of the stop conditions below fires.
        while True:
            logger.info(f'iter_count: {iter_count}' + '.' * 10)
            s_time = time.time()
            curr_node = _select(_root)
            logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _expand(curr_node)
            logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
            if curr_node.depth > _args.rollout_start_depth:
                s_time = time.time()
                _rollout(curr_node)
                logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _back_propagate(curr_node)
            logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
            # Early exit when the correct/incorrect ratio shows the problem
            # is strongly one-sided (>4:1 either way).
            if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
                    stop_reason = 'too easy'
                    break
                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
                    stop_reason = 'too hard'
                    break
            if _root.terminated:
                stop_reason = 'root terminated'
                break
            if len(terminated_nodes) >= _args.num_return_sequences:
                stop_reason = 'enough nodes'
                break
            if iter_count >= _args.max_iterations:
                stop_reason = 'max_iterations'
                break
            iter_count += 1
        logger.info(f'stop_reason: {stop_reason}')

        # Serialize the whole search result (tree included) as one JSON line.
        monte_carlo_tree = _root.collect()
        result = {
            'query': query,
            'ground_truth': ground_truth,
            'rollout_correct_answers': rollout_correct_answers,
            'rollout_incorrect_answers': rollout_incorrect_answers,
            'monte_carlo_tree': monte_carlo_tree,
        }
        result_json = json.dumps(result, ensure_ascii=False)
        logger.info(result_json)
        return result_json

    def do_sample(self, data):
        """Run :meth:`search_single` for each item; return JSON lines.

        Items whose search raises are logged and skipped, so the returned
        list may be shorter than the input.  Each item is expected to carry
        the query in ``messages[0][0]`` and the ground truth in
        ``messages[0][1]`` — presumably the dataset's standard layout;
        verify against the caller.
        """
        if not isinstance(data, list):
            data = [data]
        generated = []
        for item in data:
            logger.info(f'time: {time.ctime(time.time())}')
            try:
                messages = item['messages'][0]
                query = messages[0]['content']
                ground_truth = messages[1]['content']
                generated.append(self.search_single(query, ground_truth) + '\n')
            except Exception as e:
                logger.error(f'Error: {e}')
                logger.error(f'Traceback: {traceback.format_exc()}')
        return generated
|
|
|